"third_party/vscode:/vscode.git/clone" did not exist on "1e3cc6b2bdb0d3700e03d5f187b0bc0d07cb1b85"
Unverified Commit 06971ac4 authored by Patrick von Platen, committed by GitHub

[Bart] Refactor - fix issues, consistency with the library, naming (#8900)

* remove make on the fly linear embedding

* start refactor

* big first refactor

* save intermediate

* save intermediate

* correct mask issue

* save tests

* refactor padding masks

* make all tests pass

* further refactor

* make pegasus test pass

* fix bool if

* fix leftover tests

* continue

* bart renaming

* delete torchscript test hack

* fix imports in tests

* correct shift

* fix docs and repo cons

* re-add fix for FSTM

* typo in test

* fix typo

* fix another typo

* continue

* hot fix 2 for tf

* small fixes

* refactor types linting

* continue

* finish refactor

* fix import in tests

* better bart names

* further refactor and add test

* delete hack

* apply Sylvain's and Lysandre's comments

* small perf improv

* further perf improv

* improv perf

* fix typo

* make style

* small perf improv
parent 75627148
...@@ -105,8 +105,6 @@ BartModel
.. autoclass:: transformers.BartModel
:members: forward
.. autofunction:: transformers.models.bart.modeling_bart._prepare_bart_decoder_inputs
BartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...
...@@ -378,6 +378,7 @@ if is_torch_available():
BartForQuestionAnswering,
BartForSequenceClassification,
BartModel,
BartPretrainedModel,
PretrainedBartModel,
)
from .models.bert import (
...
...@@ -31,6 +31,7 @@ if is_torch_available():
BartForQuestionAnswering,
BartForSequenceClassification,
BartModel,
BartPretrainedModel,
PretrainedBartModel,
)
...
...@@ -15,12 +15,13 @@
"""PyTorch BART model, ported from the fairseq repo."""
import math
import random
from typing import Dict, List, Optional, Tuple
import warnings
from typing import Dict, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
...@@ -61,8 +62,479 @@ BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
# This list is incomplete. See all BART models at https://huggingface.co/models?filter=bart
BART_START_DOCSTRING = r""" def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
"""
Shift input ids one token to the right, and wrap the last non pad token (usually <eos>).
"""
prev_output_tokens = input_ids.clone()
index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
prev_output_tokens[:, 1:] = input_ids[:, :-1]
return prev_output_tokens
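# Illustrative example (not from the original file): shift_tokens_right wraps the last
# non-pad token (usually </s>) to position 0 and shifts everything else one step right.
#
#     >>> ids = torch.tensor([[0, 31414, 232, 2, 1]])   # 0=<s>, 2=</s>, 1=<pad>
#     >>> shift_tokens_right(ids, pad_token_id=1)        # -> tensor([[2, 0, 31414, 232, 2]])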
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), float("-inf"))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
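# Illustrative example (not from the original file): for a target length of 3 the causal mask
# keeps positions j <= i and blocks the future with -inf; cached positions
# (past_key_values_length) are prepended as fully visible zeros.
#
#     >>> _make_causal_mask(torch.Size([1, 3]), torch.float32)
#     # shape (1, 1, 3, 3):  [[0, -inf, -inf],
#     #                       [0,    0, -inf],
#     #                       [0,    0,    0]]
#     >>> _make_causal_mask(torch.Size([1, 3]), torch.float32, past_key_values_length=2).shape
#     # torch.Size([1, 1, 3, 5])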
def _expand_mask(
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None, past_key_values_length: int = 0
):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
if past_key_values_length > 0:
# concat fully attended attention_mask to the beginning if `past_key_values` are used
expanded_mask = torch.cat(
[
torch.ones(bsz, 1, tgt_len, past_key_values_length, device=expanded_mask.device, dtype=dtype),
expanded_mask,
],
dim=-1,
)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
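# Illustrative example (not from the original file): a padding mask [1, 1, 0] becomes an
# additive mask with 0.0 on visible tokens and the dtype minimum on pad tokens.
#
#     >>> m = _expand_mask(torch.tensor([[1, 1, 0]]), torch.float32)
#     >>> m.shape          # torch.Size([1, 1, 3, 3])
#     >>> m[0, 0, 0]       # [0., 0., torch.finfo(torch.float32).min]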
def BartLayerNorm(normalized_shape: torch.Size, eps: float = 1e-5, elementwise_affine: bool = True):
if torch.cuda.is_available():
try:
from apex.normalization import FusedLayerNorm
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
except ImportError:
pass
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
class BartLearnedPositionalEmbedding(nn.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to
the forward function.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset: int):
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = offset
assert padding_idx is not None, "`padding_idx` should not be None, but of type int"
num_embeddings += offset
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
bsz, seq_len = input_ids_shape[:2]
positions = torch.arange(
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
)
return super().forward(positions + self.offset)
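# Illustrative example (not from the original file): with Bart's offset of 2, the table has
# `num_embeddings + 2` rows and position i is read from row i + 2.
#
#     >>> pos_emb = BartLearnedPositionalEmbedding(1024, 16, padding_idx=1, offset=2)
#     >>> pos_emb(torch.Size([1, 4])).shape                             # rows 2..5 -> (4, 16)
#     >>> pos_emb(torch.Size([1, 1]), past_key_values_length=4).shape   # row 6 only -> (1, 16)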
class BartSinusoidalPositionalEmbedding(nn.Embedding):
"""This module produces sinusoidal positional embeddings of any length."""
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
super().__init__(num_positions, embedding_dim)
self.weight = self._init_weight(self.weight)
@staticmethod
def _init_weight(out: nn.Parameter):
"""
Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
the 2nd half of the vector. [dim // 2:]
"""
n_pos, dim = out.shape
position_enc = np.array(
[[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
)
out.requires_grad = False # set early to avoid an error in pytorch-1.8+
sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
out.detach_()
return out
@torch.no_grad()
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
bsz, seq_len = input_ids_shape[:2]
positions = torch.arange(
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
)
return super().forward(positions)
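# Illustrative example (not from the original file): the sinusoidal table puts all sin
# features in the first half of each row and all cos features in the second half.
#
#     >>> sin_emb = BartSinusoidalPositionalEmbedding(num_positions=8, embedding_dim=6)
#     >>> sin_emb.weight[0]                   # position 0 -> [0, 0, 0, 1, 1, 1] (sin=0, cos=1)
#     >>> sin_emb(torch.Size([1, 4])).shape   # torch.Size([4, 6])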
class BartAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
self.scaling = self.head_dim ** -0.5
self.is_decoder = is_decoder
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attn_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
assert attn_weights.size() == (
bsz * self.num_heads,
tgt_len,
src_len,
), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
if attn_mask is not None:
assert attn_mask.size() == (
bsz,
1,
tgt_len,
src_len,
), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attn_mask.size()}"
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
assert attn_output.size() == (
bsz * self.num_heads,
tgt_len,
self.head_dim,
), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
attn_output = (
attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
.transpose(1, 2)
.reshape(bsz, tgt_len, embed_dim)
)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
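# Illustrative example (not from the original file): decoder self-attention caching, assuming
# embed_dim=16 and 4 heads, so cached keys/values have head_dim 16 / 4 = 4.
#
#     >>> attn = BartAttention(embed_dim=16, num_heads=4, is_decoder=True)
#     >>> out, _, past = attn(torch.randn(2, 5, 16))                    # past[0].shape == (2, 4, 5, 4)
#     >>> out, _, past = attn(torch.randn(2, 1, 16), past_key_value=past)
#     >>> past[0].shape                                                 # (2, 4, 6, 4): cache grew by one step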
class BartEncoderLayer(nn.Module):
def __init__(self, config: BartConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = BartAttention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
)
self.normalize_before = config.normalize_before
self.self_attn_layer_norm = BartLayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
self.final_layer_norm = BartLayerNorm(self.embed_dim)
def forward(
self, hidden_states: torch.Tensor, encoder_padding_mask: torch.Tensor, output_attentions: bool = False
):
"""
Args:
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (:obj:`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
residual = hidden_states
if self.normalize_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attn_mask=encoder_padding_mask, output_attentions=output_attentions
)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
if not self.normalize_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states
if self.normalize_before:
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
if not self.normalize_before:
hidden_states = self.final_layer_norm(hidden_states)
if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
return hidden_states, attn_weights
class BartDecoderLayer(nn.Module):
def __init__(self, config: BartConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = BartAttention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.normalize_before = config.normalize_before
self.self_attn_layer_norm = BartLayerNorm(self.embed_dim)
self.encoder_attn = BartAttention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
)
self.encoder_attn_layer_norm = BartLayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
self.final_layer_norm = BartLayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_attn_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attn_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
):
residual = hidden_states
if self.normalize_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
# decoder uni-directional self-attention cached key/values tuple is at first position
self_attn_past_key_value = past_key_value[0] if past_key_value is not None else None
hidden_states, self_attn_weights, self_attn_present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attn_mask=attn_mask,
output_attentions=output_attentions,
)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
if not self.normalize_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Cross-Attention Block
cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
if self.normalize_before:
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# cross_attn cached key/values tuple is at second position
cross_attn_past_key_value = past_key_value[1] if past_key_value is not None else None
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attn_mask=encoder_attn_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
if not self.normalize_before:
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# Fully Connected
residual = hidden_states
if self.normalize_before:
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
if not self.normalize_before:
hidden_states = self.final_layer_norm(hidden_states)
# make sure decoder uni-directional self-attn at 1st position and cross-attn at 2nd position.
present_key_value = (self_attn_present_key_value, cross_attn_present_key_value)
return (
hidden_states,
self_attn_weights,
present_key_value,
cross_attn_weights,
)
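# Illustrative example (not from the original file): per-layer cache layout, assuming a
# default BartConfig (BART-large sizes: d_model=1024, 16 decoder heads -> head_dim 64).
#
#     >>> layer = BartDecoderLayer(BartConfig())
#     >>> hidden, _, present, _ = layer(torch.randn(1, 3, 1024), torch.randn(1, 7, 1024))
#     >>> # present[0]: (key, value) of the uni-directional self-attention, each (1, 16, 3, 64)
#     >>> # present[1]: (key, value) of the cross-attention,                 each (1, 16, 7, 64)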
class BartClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(
self,
input_dim: int,
inner_dim: int,
num_classes: int,
pooler_dropout: float,
):
super().__init__()
self.dense = nn.Linear(input_dim, inner_dim)
self.dropout = nn.Dropout(p=pooler_dropout)
self.out_proj = nn.Linear(inner_dim, num_classes)
def forward(self, hidden_states: torch.Tensor):
hidden_states = self.dropout(hidden_states)
hidden_states = self.dense(hidden_states)
hidden_states = torch.tanh(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.out_proj(hidden_states)
return hidden_states
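# Illustrative example (not from the original file): the head maps pooled hidden states to
# class logits.
#
#     >>> head = BartClassificationHead(input_dim=1024, inner_dim=1024, num_classes=3, pooler_dropout=0.0)
#     >>> head(torch.randn(8, 1024)).shape   # torch.Size([8, 3])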
class BartPretrainedModel(PreTrainedModel):
config_class = BartConfig
base_model_prefix = "model"
def _init_weights(self, module):
std = self.config.init_std
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, BartSinusoidalPositionalEmbedding):
pass
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
dummy_inputs = {
"attention_mask": input_ids.ne(pad_token),
"input_ids": input_ids,
}
return dummy_inputs
class PretrainedBartModel(BartPretrainedModel):
def __init_subclass__(self):
warnings.warn(
"The class `PretrainedBartModel` has been depreciated, please use `BartPretrainedModel` instead.",
FutureWarning,
)
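# Illustrative example (not from the original file): subclassing the old alias still works but
# emits a FutureWarning at class-creation time via __init_subclass__.
#
#     >>> with warnings.catch_warnings(record=True) as caught:
#     ...     warnings.simplefilter("always")
#     ...     class _OldStyleBart(PretrainedBartModel):
#     ...         pass
#     >>> caught[0].category   # FutureWarning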
BART_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
pruning heads etc.)
...@@ -76,7 +548,6 @@ BART_START_DOCSTRING = r"""
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
weights.
"""
BART_GENERATION_EXAMPLE = r""" BART_GENERATION_EXAMPLE = r"""
...@@ -94,7 +565,6 @@ BART_GENERATION_EXAMPLE = r""" ...@@ -94,7 +565,6 @@ BART_GENERATION_EXAMPLE = r"""
>>> # Generate Summary >>> # Generate Summary
>>> summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True) >>> summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
>>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]) >>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
""" """
BART_INPUTS_DOCSTRING = r""" BART_INPUTS_DOCSTRING = r"""
...@@ -118,7 +588,7 @@ BART_INPUTS_DOCSTRING = r"""
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Provide for translation and summarization training. By default, the model will create this tensor by
shifting the :obj:`input_ids` to the right, following the paper.
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`):
decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default.
...@@ -130,12 +600,24 @@ BART_INPUTS_DOCSTRING = r"""
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
`optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
cross-attention of the decoder.
past_key_values (:obj:`Tuple[Dict[str: tf.Tensor]]` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
past_key_values (:obj:`Tuple[Tuple[Tuple[torch.Tensor]]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded
representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds`
have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert
:obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds`
takes the value of :obj:`inputs_embeds`.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
...@@ -150,461 +632,350 @@ BART_INPUTS_DOCSTRING = r"""
"""
class BartEncoder(BartPretrainedModel):
def invert_mask(attention_mask):
"""Turns 1->0, 0->1, False->True, True-> False"""
assert attention_mask.dim() == 2
return attention_mask.eq(0)
def _prepare_bart_decoder_inputs(
config, input_ids, decoder_input_ids=None, decoder_padding_mask=None, causal_mask_dtype=torch.float32
):
"""
Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if none are provided.
This mimics the default behavior in fairseq. To override it pass in masks. Note: this is not called during
generation
"""
pad_token_id = config.pad_token_id
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
bsz, tgt_len = decoder_input_ids.size()
if decoder_padding_mask is None:
decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
else:
decoder_padding_mask = invert_mask(decoder_padding_mask)
if decoder_padding_mask is not None and decoder_padding_mask.shape[1] > 1:
# never mask leading token, even if it is pad
decoder_padding_mask[:, 0] = decoder_padding_mask[:, 1]
tmp = fill_with_neg_inf(torch.zeros(tgt_len, tgt_len))
mask = torch.arange(tmp.size(-1))
tmp.masked_fill_(mask < (mask + 1).view(tmp.size(-1), 1), 0)
causal_mask = tmp.to(dtype=causal_mask_dtype, device=decoder_input_ids.device)
return decoder_input_ids, decoder_padding_mask, causal_mask
class PretrainedBartModel(PreTrainedModel):
config_class = BartConfig
base_model_prefix = "model"
def _init_weights(self, module):
std = self.config.init_std
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, SinusoidalPositionalEmbedding):
pass
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
dummy_inputs = {
"attention_mask": input_ids.ne(pad_token),
"input_ids": input_ids,
}
return dummy_inputs
def _make_linear_from_emb(emb):
vocab_size, emb_size = emb.weight.shape
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
lin_layer.weight.data = emb.weight.data
return lin_layer
def shift_tokens_right(input_ids, pad_token_id):
"""Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
prev_output_tokens = input_ids.clone()
index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
prev_output_tokens[:, 1:] = input_ids[:, :-1]
return prev_output_tokens
def make_padding_mask(input_ids, padding_idx=1):
"""True for pad tokens"""
padding_mask = input_ids.eq(padding_idx)
if not padding_mask.any():
padding_mask = None
return padding_mask
# Helper Modules
class EncoderLayer(nn.Module):
def __init__(self, config: BartConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = Attention(self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout)
self.normalize_before = config.normalize_before
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
self.final_layer_norm = LayerNorm(self.embed_dim)
def forward(self, x, encoder_padding_mask, output_attentions=False):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
`(batch, src_len)` where padding elements are indicated by ``1``.
for t_tgt, t_src is excluded (or masked out), =0 means it is
included in attention
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
x, attn_weights = self.self_attn(
query=x, key=x, key_padding_mask=encoder_padding_mask, output_attentions=output_attentions
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = F.dropout(x, p=self.activation_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
if not self.normalize_before:
x = self.final_layer_norm(x)
if torch.isinf(x).any() or torch.isnan(x).any():
clamp_value = torch.finfo(x.dtype).max - 1000
x = torch.clamp(x, min=-clamp_value, max=clamp_value)
return x, attn_weights
class BartEncoder(nn.Module):
""" """
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
:class:`EncoderLayer`. :class:`BartEncoderLayer`.
Args: Args:
config: BartConfig config: BartConfig
embed_tokens (torch.nn.Embedding): output embedding
""" """
def __init__(self, config: BartConfig, embed_tokens): def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
super().__init__() super().__init__(config)
self.dropout = config.dropout self.dropout = config.dropout
self.layerdrop = config.encoder_layerdrop self.layerdrop = config.encoder_layerdrop
embed_dim = embed_tokens.embedding_dim embed_dim = config.d_model
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.padding_idx = embed_tokens.padding_idx self.padding_idx = config.pad_token_id
self.max_source_positions = config.max_position_embeddings self.max_source_positions = config.max_position_embeddings
self.embed_tokens = embed_tokens if embed_tokens is not None:
self.embed_tokens = embed_tokens
else:
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
if config.static_position_embeddings: if config.static_position_embeddings:
self.embed_positions = SinusoidalPositionalEmbedding( self.embed_positions = BartSinusoidalPositionalEmbedding(
config.max_position_embeddings, embed_dim, self.padding_idx config.max_position_embeddings, embed_dim, self.padding_idx
) )
else: else:
self.embed_positions = LearnedPositionalEmbedding( self.embed_positions = BartLearnedPositionalEmbedding(
config.max_position_embeddings, config.max_position_embeddings,
embed_dim, embed_dim,
self.padding_idx, self.padding_idx,
config.extra_pos_embeddings, config.extra_pos_embeddings,
) )
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)]) self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity() self.layernorm_embedding = BartLayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
# mbart has one extra layer_norm # mbart has one extra layer_norm
self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None self.layer_norm = BartLayerNorm(config.d_model) if config.add_final_layer_norm else None
self.init_weights()
def forward( def forward(
self, input_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=True self,
input_ids=None,
attention_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
): ):
""" r"""
Args: Args:
input_ids (LongTensor): tokens in the source language of shape input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
`(batch, src_len)` Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
attention_mask (torch.LongTensor): indicating which indices are padding tokens provide it.
Indices can be obtained using :class:`~transformers.BartTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
into associated vectors than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
for more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
Returns: # retrieve input_ids and inputs_embeds
BaseModelOutput or Tuple comprised of: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
- **x** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` if inputs_embeds is None:
- **encoder_states** (tuple(torch.FloatTensor)): all intermediate hidden states of shape `(src_len, inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
batch, embed_dim)`. Only populated if *output_hidden_states:* is True.
- **all_attentions** (tuple(torch.FloatTensor)): Attention weights for each layer.
During training might not be of length n_layers because of layer dropout.
"""
# check attention mask and invert
if attention_mask is not None:
attention_mask = invert_mask(attention_mask)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale embed_pos = self.embed_positions(input_shape)
embed_pos = self.embed_positions(input_ids)
x = inputs_embeds + embed_pos
x = self.layernorm_embedding(x)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C hidden_states = inputs_embeds + embed_pos
x = x.transpose(0, 1) hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
encoder_states = () if output_hidden_states else None encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
for encoder_layer in self.layers: for encoder_layer in self.layers:
if output_hidden_states: if output_hidden_states:
x = x.transpose(0, 1) # T x B x C -> B x T x C encoder_states = encoder_states + (hidden_states,)
encoder_states = encoder_states + (x,)
x = x.transpose(0, 1) # B x T x C -> T x B x C
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1) dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop): # skip the layer if self.training and (dropout_probability < self.layerdrop): # skip the layer
attn = None attn = None
else: else:
x, attn = encoder_layer(x, attention_mask, output_attentions=output_attentions) hidden_states, attn = encoder_layer(hidden_states, attention_mask, output_attentions=output_attentions)
if output_attentions: if output_attentions:
all_attentions = all_attentions + (attn,) all_attentions = all_attentions + (attn,)
if self.layer_norm: if self.layer_norm:
x = self.layer_norm(x) hidden_states = self.layer_norm(hidden_states)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if output_hidden_states: if output_hidden_states:
encoder_states = encoder_states + (x,) encoder_states = encoder_states + (hidden_states,)
if not return_dict: if not return_dict:
return tuple(v for v in [x, encoder_states, all_attentions] if v is not None) return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions) return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
class DecoderLayer(nn.Module):
def __init__(self, config: BartConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = Attention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.normalize_before = config.normalize_before
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.encoder_attn = Attention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
encoder_decoder_attention=True,
)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
self.final_layer_norm = LayerNorm(self.embed_dim)
def forward(
self,
x,
encoder_hidden_states,
encoder_attn_mask=None,
layer_state=None,
causal_mask=None,
decoder_padding_mask=None,
output_attentions=False,
):
residual = x
if layer_state is None:
layer_state = {}
if self.normalize_before:
x = self.self_attn_layer_norm(x)
# Self Attention
x, self_attn_weights = self.self_attn(
query=x,
key=x,
layer_state=layer_state, # adds keys to layer state
key_padding_mask=decoder_padding_mask,
attn_mask=causal_mask,
output_attentions=output_attentions,
) )
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
# Cross-Attention Block
residual = x
assert self.encoder_attn.cache_key != self.self_attn.cache_key
if self.normalize_before:
x = self.encoder_attn_layer_norm(x)
x, cross_attn_weights = self.encoder_attn(
query=x,
key=encoder_hidden_states,
key_padding_mask=encoder_attn_mask,
layer_state=layer_state, # mutates layer state
output_attentions=output_attentions,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
# Fully Connected
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = F.dropout(x, p=self.activation_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
if not self.normalize_before:
x = self.final_layer_norm(x)
return (
x,
self_attn_weights,
layer_state,
cross_attn_weights,
) # layer_state = cache for decoding
class BartDecoder(nn.Module): class BartDecoder(BartPretrainedModel):
""" """
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`DecoderLayer` Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`BartDecoderLayer`
Args: Args:
config: BartConfig config: BartConfig
embed_tokens (torch.nn.Embedding): output embedding embed_tokens (torch.nn.Embedding): output embedding
""" """
def __init__(self, config: BartConfig, embed_tokens: nn.Embedding): def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
super().__init__() super().__init__(config)
self.dropout = config.dropout self.dropout = config.dropout
self.layerdrop = config.decoder_layerdrop self.layerdrop = config.decoder_layerdrop
self.do_blenderbot_90_layernorm = config.do_blenderbot_90_layernorm # layernorm variant self.do_blenderbot_90_layernorm = config.do_blenderbot_90_layernorm # layernorm variant
self.padding_idx = embed_tokens.padding_idx self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = embed_tokens
if embed_tokens is not None:
self.embed_tokens = embed_tokens
else:
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
if config.static_position_embeddings: if config.static_position_embeddings:
self.embed_positions = SinusoidalPositionalEmbedding( self.embed_positions = BartSinusoidalPositionalEmbedding(
config.max_position_embeddings, config.d_model, config.pad_token_id config.max_position_embeddings, config.d_model, config.pad_token_id
) )
else: else:
self.embed_positions = LearnedPositionalEmbedding( self.embed_positions = BartLearnedPositionalEmbedding(
config.max_position_embeddings, config.max_position_embeddings,
config.d_model, config.d_model,
self.padding_idx, self.padding_idx,
config.extra_pos_embeddings, config.extra_pos_embeddings,
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
[DecoderLayer(config) for _ in range(config.decoder_layers)] self.layernorm_embedding = BartLayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
) # type: List[DecoderLayer] self.layer_norm = BartLayerNorm(config.d_model) if config.add_final_layer_norm else None
self.layernorm_embedding = LayerNorm(config.d_model) if config.normalize_embedding else nn.Identity()
self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None self.init_weights()
def forward( def forward(
self, self,
input_ids, input_ids=None,
encoder_hidden_states, attention_mask=None,
encoder_padding_mask, encoder_hidden_states=None,
decoder_padding_mask, encoder_attention_mask=None,
decoder_causal_mask,
past_key_values=None, past_key_values=None,
use_cache=False, inputs_embeds=None,
output_attentions=False, use_cache=None,
output_hidden_states=False, output_attentions=None,
return_dict=True, output_hidden_states=None,
return_dict=None,
): ):
""" r"""
Includes several features from "Jointly Learning to Align and Translate with Transformer Models" (Garg et al.,
EMNLP 2019).
Args: Args:
input_ids (LongTensor): previous decoder outputs of shape input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
`(batch, tgt_len)`, for teacher forcing Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
encoder_hidden_states: output from the encoder, used for provide it.
encoder-side attention
encoder_padding_mask: for ignoring pad tokens Indices can be obtained using :class:`~transformers.BartTokenizer`. See
past_key_values (dict or None): dictionary used for storing state during generation :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details.
Returns:
BaseModelOutputWithPast or tuple: `What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
- the decoder's features of shape `(batch, tgt_len, embed_dim)` Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- the cache
- hidden states - 1 for tokens that are **not masked**,
- attentions - 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
encoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`):
Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
past_key_values (:obj:`Tuple[Tuple[Tuple[torch.Tensor]]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last
:obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of
shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size,
sequence_length)`.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
into associated vectors than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
for more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
""" """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
# past_key_values_length
past_key_values_length = past_key_values[0][0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attn_mask = None
if input_shape[-1] > 1:
attn_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device)
# create decoder_padding_mask if not provided and needed
# 4.12.20 (PVP): Not a fan of this "magical" function that
# automatically creates attention_mask for padded tokens
# => this is inconsistent with other models
# => Pegasus uses the pad_token as decoder_start_token_id, so that this could
# pose some problems.
if (
attention_mask is None
and input_ids is not None
and input_shape[-1] > 1
and self.config.pad_token_id in input_ids
):
# should be kept for backwards compatibility
attention_mask = input_ids.ne(self.config.pad_token_id).to(torch.long)
# never mask leading token, even if it is pad
attention_mask[:, 0] = attention_mask[:, 1]
if attention_mask is not None and attn_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attn_mask = attn_mask + _expand_mask(
attention_mask, inputs_embeds.dtype, past_key_values_length=past_key_values_length
)
# check attention mask and invert # expand encoder attention mask
if encoder_padding_mask is not None: if encoder_hidden_states is not None and encoder_attention_mask is not None:
encoder_padding_mask = invert_mask(encoder_padding_mask) # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
# embed positions # embed positions
positions = self.embed_positions(input_ids, use_cache=use_cache) positions = self.embed_positions(input_shape, past_key_values_length)
if use_cache:
input_ids = input_ids[:, -1:]
positions = positions[:, -1:]
x = self.embed_tokens(input_ids) * self.embed_scale
if self.do_blenderbot_90_layernorm: if self.do_blenderbot_90_layernorm:
x = self.layernorm_embedding(x) hidden_states = self.layernorm_embedding(inputs_embeds)
x += positions hidden_states += positions
else: else:
x += positions hidden_states = inputs_embeds + positions
x = self.layernorm_embedding(x) hidden_states = self.layernorm_embedding(hidden_states)
x = F.dropout(x, p=self.dropout, training=self.training) hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
# Convert to Bart output format: (BS, seq_len, model_dim) -> (seq_len, BS, model_dim)
x = x.transpose(0, 1)
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
# decoder layers # decoder layers
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None all_self_attns = () if output_attentions else None
all_cross_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions else None
next_decoder_cache: List[Dict] = [] next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers): for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states: if output_hidden_states:
x = x.transpose(0, 1) all_hidden_states += (hidden_states,)
all_hidden_states += (x,)
x = x.transpose(0, 1)
dropout_probability = random.uniform(0, 1) dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop): if self.training and (dropout_probability < self.layerdrop):
continue continue
layer_state = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
x, layer_self_attn, layer_past, layer_cross_attn = decoder_layer( hidden_states, layer_self_attn, present_key_value, layer_cross_attn = decoder_layer(
x, hidden_states,
encoder_hidden_states, encoder_hidden_states,
encoder_attn_mask=encoder_padding_mask, encoder_attn_mask=encoder_attention_mask,
decoder_padding_mask=decoder_padding_mask, attn_mask=attn_mask,
layer_state=layer_state, past_key_value=past_key_value,
causal_mask=decoder_causal_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
if use_cache: if use_cache:
next_decoder_cache.append(layer_past.copy()) next_decoder_cache += (present_key_value,)
if output_attentions: if output_attentions:
all_self_attns += (layer_self_attn,) all_self_attns += (layer_self_attn,)
...@@ -612,24 +983,21 @@ class BartDecoder(nn.Module):
# add hidden states from the last decoder layer # add hidden states from the last decoder layer
if output_hidden_states: if output_hidden_states:
x = x.transpose(0, 1) all_hidden_states += (hidden_states,)
all_hidden_states += (x,)
x = x.transpose(0, 1)
if self.layer_norm: # if config.add_final_layer_norm (mBART)
x = self.layer_norm(x)
# Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim) # if config.add_final_layer_norm (mBART)
x = x.transpose(0, 1) if self.layer_norm:
encoder_hidden_states = encoder_hidden_states.transpose(0, 1) hidden_states = self.layer_norm(hidden_states)
next_cache = next_decoder_cache if use_cache else None next_cache = next_decoder_cache if use_cache else None
if not return_dict: if not return_dict:
return tuple( return tuple(
v for v in [x, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
if v is not None
) )
return BaseModelOutputWithPastAndCrossAttentions( return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=x, last_hidden_state=hidden_states,
past_key_values=next_cache, past_key_values=next_cache,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_self_attns, attentions=all_self_attns,
...@@ -637,211 +1005,11 @@ class BartDecoder(nn.Module):
)
def _reorder_buffer(attn_cache: Dict, new_order) -> Dict:
for k, input_buffer_k in attn_cache.items():
if input_buffer_k is not None:
attn_cache[k] = input_buffer_k.index_select(0, new_order)
return attn_cache
class Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim,
num_heads,
dropout=0.0,
bias=True,
encoder_decoder_attention=False, # otherwise self_attention
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.encoder_decoder_attention = encoder_decoder_attention
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"
def _shape(self, tensor, seq_len, bsz):
return tensor.contiguous().view(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
def forward(
self,
query,
key: Tensor,
key_padding_mask: Optional[Tensor] = None,
layer_state: Optional[Dict[str, Tensor]] = None,
attn_mask: Optional[Tensor] = None,
output_attentions=False,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Input shape: Time(SeqLen) x Batch x Channel"""
static_kv: bool = self.encoder_decoder_attention
tgt_len, bsz, embed_dim = query.size()
# get here for encoder-decoder attention because of static_kv
if layer_state is not None: # reuse k,v and encoder_padding_mask
saved_state = layer_state.get(self.cache_key, {})
if "prev_key" in saved_state and static_kv:
# previous time steps are cached - no need to recompute key and value if they are static
key = None
else:
# this branch is hit by encoder
saved_state = None
q = self.q_proj(query) * self.scaling
if static_kv and key is None: # cross-attention with cache
k = v = None
elif static_kv and key is not None: # cross-attention no prev_key found in cache
k = self.k_proj(key)
v = self.v_proj(key)
else: # self-attention
k = self.k_proj(query)
v = self.v_proj(query)
q = self._shape(q, tgt_len, bsz)
if k is not None:
k = self._shape(k, -1, bsz)
if v is not None:
v = self._shape(v, -1, bsz)
if saved_state:
k, v = self._concat_saved_state(k, v, saved_state, static_kv, bsz)
# Update cache
if isinstance(layer_state, dict):
cached_shape = (bsz, self.num_heads, -1, self.head_dim) # bsz must be first for reorder_cache
layer_state[self.cache_key] = dict(prev_key=k.view(*cached_shape), prev_value=v.view(*cached_shape))
src_len = k.size(1)
assert key_padding_mask is None or key_padding_mask.shape == (bsz, src_len)
attn_weights = torch.bmm(q, k.transpose(1, 2))
assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)
if attn_mask is not None:
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
# Note: deleted workaround to get around fork/join parallelism not supporting Optional types. on 2020/10/15
if key_padding_mask is not None: # don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1)
if output_attentions:
# make sure that attn_weights are included in graph
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
assert v is not None
attn_output = torch.bmm(attn_probs, v)
assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped
def _concat_saved_state(self, k, v, saved_state, static_kv, bsz) -> Tuple[Tensor]:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
prev_K = saved_state["prev_key"].view(bsz * self.num_heads, -1, self.head_dim)
prev_V = saved_state["prev_value"].view(bsz * self.num_heads, -1, self.head_dim)
new_K = prev_K if static_kv else torch.cat([prev_K, k], dim=1)
new_V = prev_V if static_kv else torch.cat([prev_V, v], dim=1)
return new_K, new_V
class BartClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
# This can trivially be shared with RobertaClassificationHead
def __init__(
self,
input_dim,
inner_dim,
num_classes,
pooler_dropout,
):
super().__init__()
self.dense = nn.Linear(input_dim, inner_dim)
self.dropout = nn.Dropout(p=pooler_dropout)
self.out_proj = nn.Linear(inner_dim, num_classes)
def forward(self, x):
x = self.dropout(x)
x = self.dense(x)
x = torch.tanh(x)
x = self.dropout(x)
x = self.out_proj(x)
return x
class LearnedPositionalEmbedding(nn.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to
the forward function.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset):
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = offset
assert padding_idx is not None
num_embeddings += offset
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
def forward(self, input_ids, use_cache=False):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_ids.shape[:2]
if use_cache:
positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing
else:
# starts at 0, ends at seq_len - 1
positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
return super().forward(positions + self.offset)
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
if torch.cuda.is_available():
try:
from apex.normalization import FusedLayerNorm
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
except ImportError:
pass
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
def fill_with_neg_inf(t):
"""FP16-compatible function that fills a input_ids with -inf."""
return t.float().fill_(float("-inf")).type_as(t)
# Public API
def _get_shape(t):
return getattr(t, "shape", None)
@add_start_docstrings( @add_start_docstrings(
"The bare BART Model outputting raw hidden-states without any specific head on top.", "The bare BART Model outputting raw hidden-states without any specific head on top.",
BART_START_DOCSTRING, BART_START_DOCSTRING,
) )
class BartModel(PretrainedBartModel): class BartModel(BartPretrainedModel):
def __init__(self, config: BartConfig): def __init__(self, config: BartConfig):
super().__init__(config) super().__init__(config)
...@@ -853,6 +1021,20 @@ class BartModel(PretrainedBartModel): ...@@ -853,6 +1021,20 @@ class BartModel(PretrainedBartModel):
self.init_weights() self.init_weights()
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, value):
self.shared = value
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
...@@ -862,20 +1044,26 @@ class BartModel(PretrainedBartModel): ...@@ -862,20 +1044,26 @@ class BartModel(PretrainedBartModel):
) )
def forward( def forward(
self, self,
input_ids, input_ids=None,
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
encoder_outputs: Optional[Tuple] = None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
): ):
if decoder_input_ids is None: # 4.12.20 (PVP): Not a fan of this "magical" function and
use_cache = False # also wonder how often it's actually used ... keep now
# for backward compatibility
# -> is this used for backward compatibility
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
...@@ -884,24 +1072,11 @@ class BartModel(PretrainedBartModel): ...@@ -884,24 +1072,11 @@ class BartModel(PretrainedBartModel):
use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# make masks if user doesn't supply
if not use_cache:
decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
self.config,
input_ids,
decoder_input_ids=decoder_input_ids,
decoder_padding_mask=decoder_attention_mask,
causal_mask_dtype=self.shared.weight.dtype,
)
else:
decoder_padding_mask, causal_mask = None, None
assert decoder_input_ids is not None
if encoder_outputs is None: if encoder_outputs is None:
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
...@@ -914,14 +1089,14 @@ class BartModel(PretrainedBartModel): ...@@ -914,14 +1089,14 @@ class BartModel(PretrainedBartModel):
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
) )
# decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn) # decoder outputs consist of (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder( decoder_outputs = self.decoder(
decoder_input_ids, input_ids=decoder_input_ids,
encoder_outputs[0], attention_mask=decoder_attention_mask,
attention_mask, encoder_hidden_states=encoder_outputs[0],
decoder_padding_mask, encoder_attention_mask=attention_mask,
decoder_causal_mask=causal_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
...@@ -942,39 +1117,40 @@ class BartModel(PretrainedBartModel): ...@@ -942,39 +1117,40 @@ class BartModel(PretrainedBartModel):
encoder_attentions=encoder_outputs.attentions, encoder_attentions=encoder_outputs.attentions,
) )
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, value):
self.shared = value
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared
def get_output_embeddings(self):
return _make_linear_from_emb(self.shared) # make it on the fly
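With the embedding accessors now defined directly on the class, the only implicit behaviour left in BartModel.forward is the automatic creation of decoder_input_ids via shift_tokens_right when neither decoder_input_ids nor decoder_inputs_embeds is given. A hedged illustration of what that helper produces; the expected values follow its docstring (shift right, wrapping the last non-pad token, usually </s>):

import torch
from transformers.models.bart.modeling_bart import shift_tokens_right

pad_token_id = 1
input_ids = torch.tensor([[0, 31, 42, 2],        # <s> x  y  </s>
                          [0, 55,  2, 1]])       # <s> x  </s> <pad>
decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
# expected, per the docstring: the wrapped </s> leads each row, e.g.
#   [[2, 0, 31, 42],
#    [2, 0, 55,  2]]
print(decoder_input_ids)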
@add_start_docstrings( @add_start_docstrings(
"The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
) )
class BartForConditionalGeneration(PretrainedBartModel): class BartForConditionalGeneration(BartPretrainedModel):
base_model_prefix = "model" base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"] _keys_to_ignore_on_load_missing = [
r"final_logits_bias",
r"encoder\.version",
r"decoder\.version",
r"lm_head\.weight",
]
def __init__(self, config: BartConfig): def __init__(self, config: BartConfig):
super().__init__(config) super().__init__(config)
base_model = BartModel(config) self.model = BartModel(config)
self.model = base_model
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
self.init_weights()
def get_encoder(self):
return self.model.get_encoder()
def get_decoder(self):
return self.model.get_decoder()
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
old_num_tokens = self.model.shared.num_embeddings
new_embeddings = super().resize_token_embeddings(new_num_tokens) new_embeddings = super().resize_token_embeddings(new_num_tokens)
self.model.shared = new_embeddings self._resize_final_logits_bias(new_num_tokens)
self._resize_final_logits_bias(new_num_tokens, old_num_tokens)
return new_embeddings return new_embeddings
def _resize_final_logits_bias(self, new_num_tokens: int, old_num_tokens: int) -> None: def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
old_num_tokens = self.final_logits_bias.shape[-1]
if new_num_tokens <= old_num_tokens: if new_num_tokens <= old_num_tokens:
new_bias = self.final_logits_bias[:, :new_num_tokens] new_bias = self.final_logits_bias[:, :new_num_tokens]
else: else:
...@@ -982,17 +1158,25 @@ class BartForConditionalGeneration(PretrainedBartModel): ...@@ -982,17 +1158,25 @@ class BartForConditionalGeneration(PretrainedBartModel):
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
self.register_buffer("final_logits_bias", new_bias) self.register_buffer("final_logits_bias", new_bias)
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
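Because _resize_final_logits_bias now reads the old size from the buffer itself, resizing the vocabulary keeps the new dedicated lm_head and the final_logits_bias buffer in sync with the shared embedding. A short usage sketch with an illustrative toy config:

from transformers import BartConfig, BartForConditionalGeneration

config = BartConfig(
    vocab_size=99, d_model=16, encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=4, decoder_attention_heads=4,
    encoder_ffn_dim=4, decoder_ffn_dim=4, max_position_embeddings=32,
)
model = BartForConditionalGeneration(config)

model.resize_token_embeddings(110)
assert model.model.shared.num_embeddings == 110      # shared embedding grew
assert model.final_logits_bias.shape[-1] == 110      # bias buffer grew with it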
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@add_end_docstrings(BART_GENERATION_EXAMPLE) @add_end_docstrings(BART_GENERATION_EXAMPLE)
def forward( def forward(
self, self,
input_ids, input_ids=None,
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
encoder_outputs=None, encoder_outputs=None,
past_key_values=None, past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
labels=None, labels=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
...@@ -1039,12 +1223,14 @@ class BartForConditionalGeneration(PretrainedBartModel): ...@@ -1039,12 +1223,14 @@ class BartForConditionalGeneration(PretrainedBartModel):
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
) )
lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias) lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
masked_lm_loss = None masked_lm_loss = None
if labels is not None: if labels is not None:
...@@ -1071,6 +1257,10 @@ class BartForConditionalGeneration(PretrainedBartModel): ...@@ -1071,6 +1257,10 @@ class BartForConditionalGeneration(PretrainedBartModel):
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
): ):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return { return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed "input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs, "encoder_outputs": encoder_outputs,
...@@ -1094,21 +1284,14 @@ class BartForConditionalGeneration(PretrainedBartModel): ...@@ -1094,21 +1284,14 @@ class BartForConditionalGeneration(PretrainedBartModel):
@staticmethod @staticmethod
def _reorder_cache(past, beam_idx): def _reorder_cache(past, beam_idx):
reordered_past = [] def _reorder_buffer(cache: Tuple[torch.Tensor], new_order) -> Dict:
return tuple(past_state.index_select(0, new_order) for past_state in cache)
reordered_past = ()
for layer_past in past: for layer_past in past:
# get the correct batch idx from decoder layer's batch dim for cross and self-attn reordered_past += (tuple(_reorder_buffer(cache, beam_idx) for cache in layer_past),)
layer_past_new = {
attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
}
reordered_past.append(layer_past_new)
return reordered_past return reordered_past
def get_encoder(self):
return self.model.encoder
def get_output_embeddings(self):
return _make_linear_from_emb(self.model.shared) # make it on the fly
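During beam search the cache handed to _reorder_cache is now a nested tuple of tensors whose first dimension is the batch (beam) dimension, so reordering reduces to index_select; callers never touch it directly. A hedged end-to-end usage sketch using the public facebook/bart-large-cnn checkpoint:

from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

article = ("PG&E stated it scheduled the blackouts in response to forecasts "
           "for high winds amid dry conditions.")
inputs = tokenizer([article], return_tensors="pt")

# beam search exercises prepare_inputs_for_generation and _reorder_cache above
summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=32, early_stopping=True)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))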
@add_start_docstrings( @add_start_docstrings(
""" """
...@@ -1117,7 +1300,7 @@ class BartForConditionalGeneration(PretrainedBartModel): ...@@ -1117,7 +1300,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
""", """,
BART_START_DOCSTRING, BART_START_DOCSTRING,
) )
class BartForSequenceClassification(PretrainedBartModel): class BartForSequenceClassification(BartPretrainedModel):
def __init__(self, config: BartConfig, **kwargs): def __init__(self, config: BartConfig, **kwargs):
super().__init__(config, **kwargs) super().__init__(config, **kwargs)
self.model = BartModel(config) self.model = BartModel(config)
...@@ -1139,11 +1322,13 @@ class BartForSequenceClassification(PretrainedBartModel): ...@@ -1139,11 +1322,13 @@ class BartForSequenceClassification(PretrainedBartModel):
) )
def forward( def forward(
self, self,
input_ids, input_ids=None,
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
encoder_outputs=None, encoder_outputs=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
labels=None, labels=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
...@@ -1159,22 +1344,33 @@ class BartForSequenceClassification(PretrainedBartModel): ...@@ -1159,22 +1344,33 @@ class BartForSequenceClassification(PretrainedBartModel):
if labels is not None: if labels is not None:
use_cache = False use_cache = False
if input_ids is None and inputs_embeds is not None:
raise NotImplementedError(
f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
)
outputs = self.model( outputs = self.model(
input_ids, input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
) )
x = outputs[0] # last hidden state hidden_states = outputs[0] # last hidden state
eos_mask = input_ids.eq(self.config.eos_token_id) eos_mask = input_ids.eq(self.config.eos_token_id)
if len(torch.unique(eos_mask.sum(1))) > 1: if len(torch.unique(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.") raise ValueError("All examples must have the same number of <eos> tokens.")
sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :] sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
:, -1, :
]
logits = self.classification_head(sentence_representation) logits = self.classification_head(sentence_representation)
loss = None loss = None
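The eos-based pooling above is unchanged in substance (only x was renamed to hidden_states): the classifier reads the decoder state at the last </s> of every sequence, which is why every example must contain the same number of eos tokens. A small tensor-only sketch of that selection with made-up shapes:

import torch

batch_size, seq_len, d_model, eos_token_id = 2, 6, 4, 2
hidden_states = torch.randn(batch_size, seq_len, d_model)
input_ids = torch.tensor([[0, 7, 8, 2, 1, 1],
                          [0, 9, 8, 7, 6, 2]])

eos_mask = input_ids.eq(eos_token_id)                # exactly one True per row here
sentence_representation = hidden_states[eos_mask, :].view(batch_size, -1, d_model)[:, -1, :]
print(sentence_representation.shape)                 # torch.Size([2, 4])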
...@@ -1206,7 +1402,7 @@ class BartForSequenceClassification(PretrainedBartModel): ...@@ -1206,7 +1402,7 @@ class BartForSequenceClassification(PretrainedBartModel):
""", """,
BART_START_DOCSTRING, BART_START_DOCSTRING,
) )
class BartForQuestionAnswering(PretrainedBartModel): class BartForQuestionAnswering(BartPretrainedModel):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
...@@ -1227,13 +1423,15 @@ class BartForQuestionAnswering(PretrainedBartModel): ...@@ -1227,13 +1423,15 @@ class BartForQuestionAnswering(PretrainedBartModel):
) )
def forward( def forward(
self, self,
input_ids, input_ids=None,
attention_mask=None, attention_mask=None,
decoder_input_ids=None, decoder_input_ids=None,
decoder_attention_mask=None, decoder_attention_mask=None,
encoder_outputs=None, encoder_outputs=None,
start_positions=None, start_positions=None,
end_positions=None, end_positions=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None, use_cache=None,
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
...@@ -1259,6 +1457,8 @@ class BartForQuestionAnswering(PretrainedBartModel): ...@@ -1259,6 +1457,8 @@ class BartForQuestionAnswering(PretrainedBartModel):
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
...@@ -1308,39 +1508,3 @@ class BartForQuestionAnswering(PretrainedBartModel): ...@@ -1308,39 +1508,3 @@ class BartForQuestionAnswering(PretrainedBartModel):
encoder_hidden_states=outputs.encoder_hidden_states, encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions, encoder_attentions=outputs.encoder_attentions,
) )
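For completeness, a hedged sketch of the question-answering call pattern; facebook/bart-base has no trained QA head, so the decoded span only illustrates the API, not a meaningful answer:

import torch
from transformers import BartForQuestionAnswering, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForQuestionAnswering.from_pretrained("facebook/bart-base")

question = "Who scheduled the blackouts?"
context = "PG&E stated it scheduled the blackouts in response to forecasts for high winds."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

start = outputs.start_logits.argmax(-1).item()
end = outputs.end_logits.argmax(-1).item()
print(tokenizer.decode(inputs["input_ids"][0, start : end + 1]))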
class SinusoidalPositionalEmbedding(nn.Embedding):
"""This module produces sinusoidal positional embeddings of any length."""
def __init__(self, num_positions, embedding_dim, padding_idx=None):
super().__init__(num_positions, embedding_dim)
self.weight = self._init_weight(self.weight)
@staticmethod
def _init_weight(out: nn.Parameter):
"""
Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
the 2nd half of the vector. [dim // 2:]
"""
n_pos, dim = out.shape
position_enc = np.array(
[[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
)
out.requires_grad = False # set early to avoid an error in pytorch-1.8+
sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
out.detach_()
return out
@torch.no_grad()
def forward(self, input_ids, use_cache=False):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_ids.shape[:2]
if use_cache:
positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing
else:
# starts at 0, ends at seq_len - 1
positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
return super().forward(positions)
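The class above is superseded by BartSinusoidalPositionalEmbedding (exercised by the updated tests further down), whose forward takes the (batch_size, seq_len) input shape plus past_key_values_length instead of input_ids and a use_cache flag. A hedged sketch, assuming positions simply continue from past_key_values_length:

import torch
from transformers.models.bart.modeling_bart import BartSinusoidalPositionalEmbedding

emb = BartSinusoidalPositionalEmbedding(num_positions=32, embedding_dim=6, padding_idx=1)

full = emb((4, 10), past_key_values_length=0)   # positions 0..9, shape (10, 6)
step = emb((4, 1), past_key_values_length=9)    # a single cached-decoding step at position 9
assert torch.allclose(full[-1], step[0])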
...@@ -577,9 +577,9 @@ class TFBartDecoder(tf.keras.layers.Layer): ...@@ -577,9 +577,9 @@ class TFBartDecoder(tf.keras.layers.Layer):
encoder_padding_mask = invert_mask(encoder_padding_mask) encoder_padding_mask = invert_mask(encoder_padding_mask)
# embed positions # embed positions
positions = self.embed_positions(input_ids, use_cache=use_cache) positions = self.embed_positions(input_ids, use_cache=(use_cache and decoder_cached_states is not None))
if use_cache: if use_cache and decoder_cached_states is not None:
input_ids = input_ids[:, -1:] input_ids = input_ids[:, -1:]
positions = positions[:, -1:] positions = positions[:, -1:]
...@@ -964,7 +964,7 @@ class TFBartModel(TFPretrainedBartModel): ...@@ -964,7 +964,7 @@ class TFBartModel(TFPretrainedBartModel):
else self.config.output_hidden_states else self.config.output_hidden_states
) )
if not inputs["use_cache"]: if not use_cache or past_key_values is None:
inputs["decoder_input_ids"], decoder_padding_mask, causal_mask = self._prepare_bart_decoder_inputs( inputs["decoder_input_ids"], decoder_padding_mask, causal_mask = self._prepare_bart_decoder_inputs(
inputs["input_ids"], inputs["input_ids"],
decoder_input_ids=inputs["decoder_input_ids"], decoder_input_ids=inputs["decoder_input_ids"],
...@@ -1154,6 +1154,7 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel): ...@@ -1154,6 +1154,7 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
assert ( assert (
decoder_cached_states decoder_cached_states
), f"decoder cached states must be truthy. got {decoder_cached_states} from the 2nd element of past" ), f"decoder cached states must be truthy. got {decoder_cached_states} from the 2nd element of past"
assert isinstance( assert isinstance(
encoder_outputs, TFBaseModelOutput encoder_outputs, TFBaseModelOutput
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}." ), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
......
...@@ -813,9 +813,6 @@ class T5Stack(T5PreTrainedModel): ...@@ -813,9 +813,6 @@ class T5Stack(T5PreTrainedModel):
def get_input_embeddings(self): def get_input_embeddings(self):
return self.embed_tokens return self.embed_tokens
def get_output_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, new_embeddings): def set_input_embeddings(self, new_embeddings):
self.embed_tokens = new_embeddings self.embed_tokens = new_embeddings
......
...@@ -450,6 +450,15 @@ class BartModel: ...@@ -450,6 +450,15 @@ class BartModel:
requires_pytorch(self) requires_pytorch(self)
class BartPretrainedModel:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class PretrainedBartModel: class PretrainedBartModel:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_pytorch(self) requires_pytorch(self)
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import tempfile import tempfile
import unittest import unittest
...@@ -49,37 +50,68 @@ if is_torch_available(): ...@@ -49,37 +50,68 @@ if is_torch_available():
pipeline, pipeline,
) )
from transformers.models.bart.modeling_bart import ( from transformers.models.bart.modeling_bart import (
SinusoidalPositionalEmbedding, BartDecoder,
_prepare_bart_decoder_inputs, BartEncoder,
invert_mask, BartSinusoidalPositionalEmbedding,
shift_tokens_right, shift_tokens_right,
) )
PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow.""" PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
def prepare_bart_inputs_dict(
config,
input_ids,
attention_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
@require_torch @require_torch
class ModelTester: class BartModelTester:
def __init__( def __init__(
self, self,
parent, parent,
batch_size=13,
seq_length=7,
is_training=True,
use_labels=False,
vocab_size=99,
hidden_size=16,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=4,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=20,
eos_token_id=2,
pad_token_id=1,
bos_token_id=0,
): ):
self.parent = parent self.parent = parent
self.batch_size = 13 self.batch_size = batch_size
self.seq_length = 7 self.seq_length = seq_length
self.is_training = True self.is_training = is_training
self.use_labels = False self.use_labels = use_labels
self.vocab_size = 99 self.vocab_size = vocab_size
self.hidden_size = 16 self.hidden_size = hidden_size
self.num_hidden_layers = 2 self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = 4 self.num_attention_heads = num_attention_heads
self.intermediate_size = 4 self.intermediate_size = intermediate_size
self.hidden_act = "gelu" self.hidden_act = hidden_act
self.hidden_dropout_prob = 0.1 self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = 0.1 self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = 20 self.max_position_embeddings = max_position_embeddings
self.eos_token_id = 2 self.eos_token_id = eos_token_id
self.pad_token_id = 1 self.pad_token_id = pad_token_id
self.bos_token_id = 0 self.bos_token_id = bos_token_id
torch.manual_seed(0) torch.manual_seed(0)
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
...@@ -111,21 +143,67 @@ class ModelTester: ...@@ -111,21 +143,67 @@ class ModelTester:
config, inputs_dict = self.prepare_config_and_inputs() config, inputs_dict = self.prepare_config_and_inputs()
inputs_dict["decoder_input_ids"] = inputs_dict["input_ids"] inputs_dict["decoder_input_ids"] = inputs_dict["input_ids"]
inputs_dict["decoder_attention_mask"] = inputs_dict["attention_mask"] inputs_dict["decoder_attention_mask"] = inputs_dict["attention_mask"]
inputs_dict["use_cache"] = False
return config, inputs_dict return config, inputs_dict
def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
model = BartModel(config=config).get_decoder().to(torch_device).eval()
input_ids = inputs_dict["input_ids"]
def prepare_bart_inputs_dict( # first forward pass
config, outputs = model(input_ids, use_cache=True)
input_ids,
attention_mask=None, output, past_key_values = outputs.to_tuple()
):
if attention_mask is None: # create hypothetical multiple next token and extent to next_input_ids
attention_mask = input_ids.ne(config.pad_token_id) next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
return {
"input_ids": input_ids, # append to next input_ids and
"attention_mask": attention_mask, next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
}
output_from_no_past = model(next_input_ids)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2))
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
model = BartModel(config=config).to(torch_device).eval()
outputs = model(**inputs_dict)
encoder_last_hidden_state = outputs.encoder_last_hidden_state
last_hidden_state = outputs.last_hidden_state
with tempfile.TemporaryDirectory() as tmpdirname:
encoder = model.get_encoder()
encoder.save_pretrained(tmpdirname)
encoder = BartEncoder.from_pretrained(tmpdirname).to(torch_device)
encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[
0
]
self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)
with tempfile.TemporaryDirectory() as tmpdirname:
decoder = model.get_decoder()
decoder.save_pretrained(tmpdirname)
decoder = BartDecoder.from_pretrained(tmpdirname).to(torch_device)
last_hidden_state_2 = decoder(
input_ids=inputs_dict["decoder_input_ids"],
attention_mask=inputs_dict["decoder_attention_mask"],
encoder_hidden_states=encoder_last_hidden_state,
encoder_attention_mask=inputs_dict["attention_mask"],
)[0]
self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3)
@require_torch @require_torch
...@@ -142,7 +220,7 @@ class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ...@@ -142,7 +220,7 @@ class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
test_missing_keys = False test_missing_keys = False
def setUp(self): def setUp(self):
self.model_tester = ModelTester(self) self.model_tester = BartModelTester(self)
self.config_tester = ConfigTester(self, config_class=BartConfig) self.config_tester = ConfigTester(self, config_class=BartConfig)
def test_config(self): def test_config(self):
...@@ -169,23 +247,25 @@ class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ...@@ -169,23 +247,25 @@ class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
config, inputs_dict = self.model_tester.prepare_config_and_inputs() config, inputs_dict = self.model_tester.prepare_config_and_inputs()
config.use_cache = False config.use_cache = False
inputs_dict["input_ids"][:, -2:] = config.pad_token_id inputs_dict["input_ids"][:, -2:] = config.pad_token_id
decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
config, inputs_dict["input_ids"]
)
model = BartModel(config).to(torch_device).eval()
model = BartModel(config).to(torch_device).eval()
decoder_features_with_created_mask = model(**inputs_dict)[0] decoder_features_with_created_mask = model(**inputs_dict)[0]
decoder_input_ids = shift_tokens_right(inputs_dict["input_ids"], config.pad_token_id)
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
decoder_attention_mask[:, 0] = decoder_attention_mask[:, 1]
decoder_features_with_passed_mask = model( decoder_features_with_passed_mask = model(
decoder_attention_mask=invert_mask(decoder_attn_mask), decoder_input_ids=decoder_input_ids, **inputs_dict decoder_attention_mask=decoder_attention_mask, decoder_input_ids=decoder_input_ids, **inputs_dict
)[0] )[0]
assert_tensors_close(decoder_features_with_passed_mask, decoder_features_with_created_mask) assert_tensors_close(decoder_features_with_passed_mask, decoder_features_with_created_mask)
useless_mask = torch.zeros_like(decoder_attn_mask) useless_mask = torch.zeros_like(decoder_attention_mask)
decoder_features = model(decoder_attention_mask=useless_mask, **inputs_dict)[0] decoder_features = model(decoder_attention_mask=useless_mask, **inputs_dict)[0]
self.assertTrue(isinstance(decoder_features, torch.Tensor)) # no hidden states or attentions self.assertTrue(isinstance(decoder_features, torch.Tensor)) # no hidden states or attentions
self.assertEqual( self.assertEqual(
decoder_features.size(), (self.model_tester.batch_size, self.model_tester.seq_length, config.d_model) decoder_features.size(), (self.model_tester.batch_size, self.model_tester.seq_length, config.d_model)
) )
if decoder_attn_mask.min().item() < -1e3: # some tokens were masked if decoder_attention_mask.min().item() == 0: # some tokens were masked
self.assertFalse((decoder_features_with_created_mask == decoder_features).all().item()) self.assertFalse((decoder_features_with_created_mask == decoder_features).all().item())
# Test different encoder attention masks # Test different encoder attention masks
...@@ -204,13 +284,43 @@ class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ...@@ -204,13 +284,43 @@ class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], []) self.assertEqual(info["missing_keys"], [])
@unittest.skip("Passing inputs_embeds not implemented for Bart.") def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
def test_encoder_decoder_model_standalone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
# BartForSequenceClassification does not support inputs_embeds
def test_inputs_embeds(self): def test_inputs_embeds(self):
pass config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@unittest.skip("TODO: Decoder embeddings cannot be resized at the moment") for model_class in (BartModel, BartForConditionalGeneration, BartForQuestionAnswering):
def test_resize_embeddings_untied(self): model = model_class(config)
pass model.to(torch_device)
model.eval()
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
if not self.is_encoder_decoder:
input_ids = inputs["input_ids"]
del inputs["input_ids"]
else:
encoder_input_ids = inputs["input_ids"]
decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
del inputs["input_ids"]
inputs.pop("decoder_input_ids", None)
wte = model.get_input_embeddings()
if not self.is_encoder_decoder:
inputs["inputs_embeds"] = wte(input_ids)
else:
inputs["inputs_embeds"] = wte(encoder_input_ids)
inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
with torch.no_grad():
model(**inputs)[0]
@require_sentencepiece @require_sentencepiece
@require_tokenizers @require_tokenizers
...@@ -386,20 +496,6 @@ class BartHeadTests(unittest.TestCase): ...@@ -386,20 +496,6 @@ class BartHeadTests(unittest.TestCase):
model = BartForConditionalGeneration(config).eval().to(torch_device) model = BartForConditionalGeneration(config).eval().to(torch_device)
model(**model.dummy_inputs) model(**model.dummy_inputs)
def test_prepare_bart_decoder_inputs(self):
config, *_ = self._get_config_and_data()
input_ids = _long_tensor(([4, 4, 2]))
decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
ignore = float("-inf")
decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
config, input_ids, decoder_input_ids
)
expected_causal_mask = torch.tensor(
[[0, ignore, ignore], [0, 0, ignore], [0, 0, 0]] # never attend to the final token, because it's pad
).to(input_ids.device)
self.assertEqual(decoder_attn_mask.size(), decoder_input_ids.size())
self.assertTrue(torch.eq(expected_causal_mask, causal_mask).all())
def test_resize_tokens_embeddings_more(self): def test_resize_tokens_embeddings_more(self):
config, input_ids, _ = self._get_config_and_data() config, input_ids, _ = self._get_config_and_data()
...@@ -470,14 +566,14 @@ class BartModelIntegrationTests(unittest.TestCase): ...@@ -470,14 +566,14 @@ class BartModelIntegrationTests(unittest.TestCase):
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE)) self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))
@slow @slow
def test_bart_base_mask_filling(self): def test_base_mask_filling(self):
pbase = pipeline(task="fill-mask", model="facebook/bart-base") pbase = pipeline(task="fill-mask", model="facebook/bart-base")
src_text = [" I went to the <mask>."] src_text = [" I went to the <mask>."]
results = [x["token_str"] for x in pbase(src_text)] results = [x["token_str"] for x in pbase(src_text)]
assert "Ġbathroom" in results assert "Ġbathroom" in results
@slow @slow
def test_bart_large_mask_filling(self): def test_large_mask_filling(self):
plarge = pipeline(task="fill-mask", model="facebook/bart-large") plarge = pipeline(task="fill-mask", model="facebook/bart-large")
src_text = [" I went to the <mask>."] src_text = [" I went to the <mask>."]
results = [x["token_str"] for x in plarge(src_text)] results = [x["token_str"] for x in plarge(src_text)]
...@@ -608,7 +704,7 @@ class BartModelIntegrationTests(unittest.TestCase): ...@@ -608,7 +704,7 @@ class BartModelIntegrationTests(unittest.TestCase):
@require_torch @require_torch
class TestSinusoidalPositionalEmbeddings(unittest.TestCase): class TestBartSinusoidalPositionalEmbeddings(unittest.TestCase):
desired_weights = [ desired_weights = [
[0, 0, 0, 0, 0], [0, 0, 0, 0, 0],
[0.84147096, 0.82177866, 0.80180490, 0.78165019, 0.76140374], [0.84147096, 0.82177866, 0.80180490, 0.78165019, 0.76140374],
...@@ -616,38 +712,30 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase): ...@@ -616,38 +712,30 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
] ]
def test_positional_emb_cache_logic(self): def test_positional_emb_cache_logic(self):
pad = 1 emb1 = BartSinusoidalPositionalEmbedding(num_positions=32, embedding_dim=6, padding_idx=1).to(torch_device)
input_ids = torch.tensor([[4, 10]], dtype=torch.long, device=torch_device) no_cache = emb1((4, 10), past_key_values_length=0)
emb1 = SinusoidalPositionalEmbedding(num_positions=32, embedding_dim=6, padding_idx=pad).to(torch_device) yes_cache = emb1((4, 10), past_key_values_length=2)
no_cache = emb1(input_ids, use_cache=False)
yes_cache = emb1(input_ids, use_cache=True) self.assertTrue(no_cache.shape == yes_cache.shape == (10, 6))
self.assertEqual((1, 1, 6), yes_cache.shape) # extra dim to allow broadcasting, feel free to delete! self.assertListEqual(no_cache[2:].tolist(), yes_cache[:-2].tolist())
self.assertListEqual(no_cache[-1].tolist(), yes_cache[0][0].tolist())
def test_odd_embed_dim(self): def test_odd_embed_dim(self):
# odd embedding_dim is allowed # odd embedding_dim is allowed
SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=0).to(torch_device) BartSinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=0).to(torch_device)
# odd num_positions is allowed # odd num_positions is allowed
SinusoidalPositionalEmbedding(num_positions=5, embedding_dim=4, padding_idx=0).to(torch_device) BartSinusoidalPositionalEmbedding(num_positions=5, embedding_dim=4, padding_idx=0).to(torch_device)
def test_positional_emb_weights_against_marian(self): def test_positional_emb_weights_against_marian(self):
pad = 1 pad = 1
emb1 = SinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512, padding_idx=pad).to(torch_device) emb1 = BartSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512, padding_idx=pad).to(
torch_device
)
weights = emb1.weight.data[:3, :5].tolist() weights = emb1.weight.data[:3, :5].tolist()
for i, (expected_weight, actual_weight) in enumerate(zip(self.desired_weights, weights)): for i, (expected_weight, actual_weight) in enumerate(zip(self.desired_weights, weights)):
for j in range(5): for j in range(5):
self.assertAlmostEqual(expected_weight[j], actual_weight[j], places=3) self.assertAlmostEqual(expected_weight[j], actual_weight[j], places=3)
# test that forward pass is just a lookup, there is no ignore padding logic
input_ids = torch.tensor([[4, 10, pad, pad, pad]], dtype=torch.long, device=torch_device)
no_cache_pad_zero = emb1(input_ids)
self.assertTrue(
torch.allclose(
torch.tensor(self.desired_weights, device=torch_device), no_cache_pad_zero[:3, :5], atol=1e-3
)
)
def test_child_config_equivalence(self): def test_child_config_equivalence(self):
"""Test that configs associated with children of BartForConditionalGeneration are identical.""" """Test that configs associated with children of BartForConditionalGeneration are identical."""
child_classes = [BlenderbotConfig, MBartConfig, MarianConfig, PegasusConfig] child_classes = [BlenderbotConfig, MBartConfig, MarianConfig, PegasusConfig]
......
...@@ -104,9 +104,6 @@ class BlenderbotTesterMixin(ModelTesterMixin, unittest.TestCase): ...@@ -104,9 +104,6 @@ class BlenderbotTesterMixin(ModelTesterMixin, unittest.TestCase):
self.model_tester = BlenderbotModelTester(self) self.model_tester = BlenderbotModelTester(self)
self.config_tester = ConfigTester(self, config_class=BlenderbotConfig) self.config_tester = ConfigTester(self, config_class=BlenderbotConfig)
def test_inputs_embeds(self):
pass
def test_initialization_module(self): def test_initialization_module(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = BlenderbotForConditionalGeneration(config).model model = BlenderbotForConditionalGeneration(config).model
......
...@@ -302,6 +302,8 @@ class ModelTesterMixin: ...@@ -302,6 +302,8 @@ class ModelTesterMixin:
# Question Answering model returns start_logits and end_logits # Question Answering model returns start_logits and end_logits
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values(): if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
correct_outlen += 1 # start_logits and end_logits instead of only 1 output correct_outlen += 1 # start_logits and end_logits instead of only 1 output
if "past_key_values" in outputs:
correct_outlen += 1 # past_key_values have been returned
self.assertEqual(out_len, correct_outlen) self.assertEqual(out_len, correct_outlen)
...@@ -386,7 +388,7 @@ class ModelTesterMixin: ...@@ -386,7 +388,7 @@ class ModelTesterMixin:
try: try:
if model.config.is_encoder_decoder: if model.config.is_encoder_decoder:
model.config.use_cache = False # TODO: this should be deleted after bug #7474 is solved model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
input_ids = inputs["input_ids"] input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"] attention_mask = inputs["attention_mask"]
decoder_input_ids = inputs["decoder_input_ids"] decoder_input_ids = inputs["decoder_input_ids"]
...@@ -1020,7 +1022,6 @@ class ModelTesterMixin: ...@@ -1020,7 +1022,6 @@ class ModelTesterMixin:
) )
def test_inputs_embeds(self): def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
......
...@@ -37,7 +37,7 @@ from transformers.testing_utils import ( ...@@ -37,7 +37,7 @@ from transformers.testing_utils import (
torch_device, torch_device,
) )
from .test_modeling_bart import ModelTester as BartModelTester from .test_modeling_bart import BartModelTester
from .test_modeling_dpr import DPRModelTester from .test_modeling_dpr import DPRModelTester
from .test_modeling_t5 import T5ModelTester from .test_modeling_t5 import T5ModelTester
......
...@@ -344,13 +344,6 @@ class TFModelTesterMixin: ...@@ -344,13 +344,6 @@ class TFModelTesterMixin:
tf_hidden_states[pt_nans] = 0 tf_hidden_states[pt_nans] = 0
max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states)) max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
# Debug info (remove when fixed)
if max_diff >= 4e-2:
print("===")
print(model_class)
print(config)
print(inputs_dict)
print(pt_inputs_dict)
self.assertLessEqual(max_diff, 4e-2) self.assertLessEqual(max_diff, 4e-2)
# Check we can load pt model in tf and vice-versa with checkpoint => model functions # Check we can load pt model in tf and vice-versa with checkpoint => model functions
......
...@@ -28,6 +28,8 @@ PATH_TO_DOC = "docs/source/model_doc" ...@@ -28,6 +28,8 @@ PATH_TO_DOC = "docs/source/model_doc"
# Update this list for models that are not tested with a comment explaining the reason it should not be. # Update this list for models that are not tested with a comment explaining the reason it should not be.
# Being in this list is an exception and should **not** be the rule. # Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = [ IGNORE_NON_TESTED = [
"BartDecoder", # Building part of bigger (tested) model.
"BartEncoder", # Building part of bigger (tested) model.
"BertLMHeadModel", # Needs to be setup as decoder. "BertLMHeadModel", # Needs to be setup as decoder.
"DPREncoder", # Building part of bigger (tested) model. "DPREncoder", # Building part of bigger (tested) model.
"DPRSpanPredictor", # Building part of bigger (tested) model. "DPRSpanPredictor", # Building part of bigger (tested) model.
...@@ -58,9 +60,11 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [ ...@@ -58,9 +60,11 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
# Update this list for models that are not documented with a comment explaining the reason it should not be. # Update this list for models that are not documented with a comment explaining the reason it should not be.
# Being in this list is an exception and should **not** be the rule. # Being in this list is an exception and should **not** be the rule.
IGNORE_NON_DOCUMENTED = [ IGNORE_NON_DOCUMENTED = [
"BartDecoder", # Building part of bigger (documented) model.
"BartEncoder", # Building part of bigger (documented) model.
"DPREncoder", # Building part of bigger (documented) model. "DPREncoder", # Building part of bigger (documented) model.
"DPRSpanPredictor", # Building part of bigger (documented) model. "DPRSpanPredictor", # Building part of bigger (documented) model.
"T5Stack", # Building part of bigger (tested) model. "T5Stack", # Building part of bigger (documented) model.
"TFDPREncoder", # Building part of bigger (documented) model. "TFDPREncoder", # Building part of bigger (documented) model.
"TFDPRSpanPredictor", # Building part of bigger (documented) model. "TFDPRSpanPredictor", # Building part of bigger (documented) model.
] ]
...@@ -78,6 +82,8 @@ MODEL_NAME_TO_DOC_FILE = { ...@@ -78,6 +82,8 @@ MODEL_NAME_TO_DOC_FILE = {
# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and # Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
# should **not** be the rule. # should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = [ IGNORE_NON_AUTO_CONFIGURED = [
"BartDecoder",
"BartEncoder",
"DPRContextEncoder", "DPRContextEncoder",
"DPREncoder", "DPREncoder",
"DPRReader", "DPRReader",
......