Unverified commit 58918c76, authored by Sam Shleifer and committed by GitHub

[bart] add config.extra_pos_embeddings to facilitate reuse (#5190)

parent b28b5371
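
For orientation (editorial note, not part of the diff): BART's learned position table reserves a few leading rows, a fairseq-era convention in which real positions start counting above the padding id. Until now that offset was hard-coded inside LearnedPositionalEmbedding as padding_idx + 1; this commit surfaces it as config.extra_pos_embeddings so models that reuse the module can set it explicitly. A minimal sketch of the arithmetic, assuming BART-large style defaults (pad_token_id=1, max_position_embeddings=1024):

```python
# Illustrative numbers only; the names mirror the config fields touched in this diff.
pad_token_id = 1
max_position_embeddings = 1024

extra_pos_embeddings = pad_token_id + 1            # resolves to 2 for BART
table_rows = max_position_embeddings + extra_pos_embeddings

assert extra_pos_embeddings == 2
assert table_rows == 1026  # rows in the learned position table of the pretrained checkpoints
```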
@@ -41,6 +41,7 @@ class BartConfig(PretrainedConfig):
def __init__(
self,
activation_dropout=0.0,
extra_pos_embeddings=2,
activation_function="gelu",
vocab_size=50265,
d_model=1024,
@@ -118,6 +119,9 @@ class BartConfig(PretrainedConfig):
# Classifier stuff
self.classif_dropout = classifier_dropout
# pos embedding offset
self.extra_pos_embeddings = self.pad_token_id + 1
@property
def num_attention_heads(self) -> int:
return self.encoder_attention_heads
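
Note that, as the hunk above shows, the stored value is derived from pad_token_id rather than taken from the new constructor argument, so with BART's default pad_token_id of 1 it matches the default of 2. A self-contained sketch of that derivation (a toy stand-in, not the real BartConfig):

```python
class BartLikeConfigSketch:
    """Toy stand-in for BartConfig, keeping only the fields touched in this hunk."""

    def __init__(self, extra_pos_embeddings=2, pad_token_id=1):
        self.pad_token_id = pad_token_id
        # mirrors the diff: the stored offset is derived from the padding id
        self.extra_pos_embeddings = self.pad_token_id + 1


config = BartLikeConfigSketch()
assert config.extra_pos_embeddings == 2
```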
@@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss
from .activations import ACT2FN
from .configuration_bart import BartConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_utils import PreTrainedModel, create_position_ids_from_input_ids
from .modeling_utils import PreTrainedModel
logger = logging.getLogger(__name__)
@@ -96,6 +96,7 @@ BART_INPUTS_DOCSTRING = r"""
def invert_mask(attention_mask):
"""Turns 1->0, 0->1, False->True, True-> False"""
assert attention_mask.dim() == 2
return attention_mask.eq(0)
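
The added assert simply guards invert_mask against masks that are not [bsz, seq_len]. A tiny usage example with made-up values:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])  # 1 = real token, 0 = padding
padding_mask = attention_mask.eq(0)            # what invert_mask returns
# tensor([[False, False, False,  True]]) -- True now marks the padded positions
```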
@@ -261,7 +262,7 @@ class BartEncoder(nn.Module):
)
else:
self.embed_positions = LearnedPositionalEmbedding(
config.max_position_embeddings, embed_dim, self.padding_idx,
config.max_position_embeddings, embed_dim, self.padding_idx, config.extra_pos_embeddings,
)
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
@@ -435,7 +436,7 @@ class BartDecoder(nn.Module):
)
else:
self.embed_positions = LearnedPositionalEmbedding(
config.max_position_embeddings, config.d_model, self.padding_idx,
config.max_position_embeddings, config.d_model, self.padding_idx, config.extra_pos_embeddings,
)
self.layers = nn.ModuleList(
[DecoderLayer(config) for _ in range(config.decoder_layers)]
@@ -745,23 +746,23 @@ class LearnedPositionalEmbedding(nn.Embedding):
position ids are passed to the forward function.
"""
def __init__(
self, num_embeddings: int, embedding_dim: int, padding_idx: int,
):
# if padding_idx is specified then offset the embedding ids by
# this index and adjust num_embeddings appropriately
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset):
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack.
self.offset = offset
assert padding_idx is not None
num_embeddings += padding_idx + 1 # WHY?
num_embeddings += offset
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
def forward(self, input, use_cache=False):
def forward(self, input_ids, use_cache=False):
"""Input is expected to be of size [bsz x seqlen]."""
if use_cache: # the position is our current step in the decoded sequence
pos = int(self.padding_idx + input.size(1))
positions = input.data.new(1, 1).fill_(pos)
bsz, seq_len = input_ids.shape[:2]
if use_cache:
positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing
else:
positions = create_position_ids_from_input_ids(input, self.padding_idx)
return super().forward(positions)
# starts at 0, ends at seq_len - 1
positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
return super().forward(positions + self.offset)
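
A standalone sketch of the rewritten positional embedding, handy for checking both code paths in isolation. The class below copies the logic from the hunk above but is an illustration, not an import from the library:

```python
import torch
import torch.nn as nn


class LearnedPositionalEmbeddingSketch(nn.Embedding):
    """Mirror of the new LearnedPositionalEmbedding: positions run 0..seq_len-1
    and are shifted by `offset` before the table lookup."""

    def __init__(self, num_embeddings, embedding_dim, padding_idx, offset):
        self.offset = offset
        super().__init__(num_embeddings + offset, embedding_dim, padding_idx=padding_idx)

    def forward(self, input_ids, use_cache=False):
        bsz, seq_len = input_ids.shape[:2]
        if use_cache:
            # generation: only the newest position is needed
            positions = input_ids.data.new(1, 1).fill_(seq_len - 1)
        else:
            positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
        return super().forward(positions + self.offset)


emb = LearnedPositionalEmbeddingSketch(1024, 8, padding_idx=1, offset=2)
tokens = torch.full((2, 5), 7, dtype=torch.long)
assert emb(tokens).shape == (5, 8)                     # one row per position
assert emb(tokens, use_cache=True).shape == (1, 1, 8)  # only the latest step during decoding
```

In the non-cached path the returned positions are shared across the batch, so the (seq_len, embedding_dim) result broadcasts against the (bsz, seq_len, embedding_dim) token embeddings when they are summed.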
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
@@ -26,7 +26,6 @@ from torch.nn import CrossEntropyLoss, MSELoss
from .configuration_roberta import RobertaConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_bert import BertEmbeddings, BertLayerNorm, BertModel, BertPreTrainedModel, gelu
from .modeling_utils import create_position_ids_from_input_ids
logger = logging.getLogger(__name__)
@@ -733,3 +732,17 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
def create_position_ids_from_input_ids(input_ids, padding_idx):
""" Replace non-padding symbols with their position numbers. Position numbers begin at
padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
`utils.make_positions`.
:param torch.Tensor input_ids:
:param int padding_idx:
:return torch.Tensor:
"""
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
mask = input_ids.ne(padding_idx).int()
incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
return incremental_indices.long() + padding_idx
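
A worked example of the helper that now lives alongside the RoBERTa code. The function is repeated here so the snippet runs on its own; the token ids are illustrative, with 1 playing the role of the padding id:

```python
import torch


def create_position_ids_from_input_ids(input_ids, padding_idx):
    # identical to the helper above
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
    return incremental_indices.long() + padding_idx


input_ids = torch.tensor([[0, 31414, 232, 2, 1, 1]])  # trailing 1s are padding
print(create_position_ids_from_input_ids(input_ids, padding_idx=1))
# tensor([[2, 3, 4, 5, 1, 1]]) -- real tokens count up from padding_idx + 1, pads stay at padding_idx
```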
@@ -2090,20 +2090,6 @@ class SequenceSummary(nn.Module):
return output
def create_position_ids_from_input_ids(input_ids, padding_idx):
""" Replace non-padding symbols with their position numbers. Position numbers begin at
padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
`utils.make_positions`.
:param torch.Tensor x:
:return torch.Tensor:
"""
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
mask = input_ids.ne(padding_idx).int()
incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
return incremental_indices.long() + padding_idx
def prune_linear_layer(layer, index, dim=0):
""" Prune a linear layer (a model parameters) to keep only entries in index.
Return the pruned layer as a new layer with requires_grad=True.
@@ -34,9 +34,8 @@ if is_torch_available():
RobertaForSequenceClassification,
RobertaForTokenClassification,
)
from transformers.modeling_roberta import RobertaEmbeddings
from transformers.modeling_roberta import RobertaEmbeddings, create_position_ids_from_input_ids
from transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST
from transformers.modeling_utils import create_position_ids_from_input_ids
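
Downstream code that imported this helper from modeling_utils needs the same one-line change as the test above. A defensive sketch, assuming the flat transformers.modeling_* layout used in this tree:

```python
try:
    # new location after this commit
    from transformers.modeling_roberta import create_position_ids_from_input_ids
except ImportError:
    # older trees still expose it from modeling_utils
    from transformers.modeling_utils import create_position_ids_from_input_ids
```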
class RobertaModelTester: