Unverified Commit 58918c76 authored by Sam Shleifer, committed by GitHub

[bart] add config.extra_pos_embeddings to facilitate reuse (#5190)

parent b28b5371
src/transformers/configuration_bart.py

@@ -41,6 +41,7 @@ class BartConfig(PretrainedConfig):
     def __init__(
         self,
         activation_dropout=0.0,
+        extra_pos_embeddings=2,
         activation_function="gelu",
         vocab_size=50265,
         d_model=1024,
@@ -118,6 +119,9 @@ class BartConfig(PretrainedConfig):
         # Classifier stuff
         self.classif_dropout = classifier_dropout
+        # pos embedding offset
+        self.extra_pos_embeddings = self.pad_token_id + 1

     @property
     def num_attention_heads(self) -> int:
         return self.encoder_attention_heads
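To make the new knob concrete, here is a minimal sketch of how it surfaces on the config (assuming the transformers version at this commit; note that the stored value above is derived from pad_token_id rather than from the keyword argument):

```python
from transformers import BartConfig

# With BART's default pad_token_id of 1, the stored offset works out to 2,
# matching the keyword argument's default.
config = BartConfig()
print(config.pad_token_id)          # 1
print(config.extra_pos_embeddings)  # 2, i.e. pad_token_id + 1
```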
src/transformers/modeling_bart.py

@@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss
 from .activations import ACT2FN
 from .configuration_bart import BartConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from .modeling_utils import PreTrainedModel, create_position_ids_from_input_ids
+from .modeling_utils import PreTrainedModel

 logger = logging.getLogger(__name__)

@@ -96,6 +96,7 @@ BART_INPUTS_DOCSTRING = r"""
 def invert_mask(attention_mask):
+    """Turns 1->0, 0->1, False->True, True->False"""
     assert attention_mask.dim() == 2
     return attention_mask.eq(0)

@@ -261,7 +262,7 @@ class BartEncoder(nn.Module):
             )
         else:
             self.embed_positions = LearnedPositionalEmbedding(
-                config.max_position_embeddings, embed_dim, self.padding_idx,
+                config.max_position_embeddings, embed_dim, self.padding_idx, config.extra_pos_embeddings,
             )
         self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
         self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()

@@ -435,7 +436,7 @@ class BartDecoder(nn.Module):
             )
         else:
             self.embed_positions = LearnedPositionalEmbedding(
-                config.max_position_embeddings, config.d_model, self.padding_idx,
+                config.max_position_embeddings, config.d_model, self.padding_idx, config.extra_pos_embeddings,
             )
         self.layers = nn.ModuleList(
             [DecoderLayer(config) for _ in range(config.decoder_layers)]
@@ -745,23 +746,23 @@ class LearnedPositionalEmbedding(nn.Embedding):
     position ids are passed to the forward function.
     """

-    def __init__(
-        self, num_embeddings: int, embedding_dim: int, padding_idx: int,
-    ):
-        # if padding_idx is specified then offset the embedding ids by
-        # this index and adjust num_embeddings appropriately
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset):
+        # Bart is set up so that if padding_idx is specified, the embedding ids are offset by 2
+        # and num_embeddings is adjusted accordingly. Other models don't have this hack.
+        self.offset = offset
         assert padding_idx is not None
-        num_embeddings += padding_idx + 1  # WHY?
+        num_embeddings += offset
         super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)

-    def forward(self, input, use_cache=False):
+    def forward(self, input_ids, use_cache=False):
         """Input is expected to be of size [bsz x seqlen]."""
-        if use_cache:  # the position is our current step in the decoded sequence
-            pos = int(self.padding_idx + input.size(1))
-            positions = input.data.new(1, 1).fill_(pos)
+        bsz, seq_len = input_ids.shape[:2]
+        if use_cache:
+            positions = input_ids.data.new(1, 1).fill_(seq_len - 1)  # called before slicing, so this is the current step
         else:
-            positions = create_position_ids_from_input_ids(input, self.padding_idx)
-        return super().forward(positions)
+            # positions run from 0 to seq_len - 1; the offset is added before the embedding lookup
+            positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
+        return super().forward(positions + self.offset)


 def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
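A short usage sketch of the reworked LearnedPositionalEmbedding, with toy sizes and assuming it is imported from transformers.modeling_bart as in this version of the library: the constructor grows the embedding table by `offset` rows, and `forward` shifts a 0-based position range by the same offset before the lookup.

```python
import torch
from transformers.modeling_bart import LearnedPositionalEmbedding

# Toy sizes for illustration only.
emb = LearnedPositionalEmbedding(num_embeddings=10, embedding_dim=4, padding_idx=1, offset=2)
print(emb.weight.shape)  # torch.Size([12, 4]): 10 positions plus 2 extra rows for the offset

input_ids = torch.ones(2, 5, dtype=torch.long)  # [bsz=2, seq_len=5]; the token values do not matter here
out = emb(input_ids)                            # looks up rows 2..6, i.e. positions 0..4 shifted by the offset
print(out.shape)                                # torch.Size([5, 4]); callers broadcast-add this to the token embeddings
```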
src/transformers/modeling_roberta.py

@@ -26,7 +26,6 @@ from torch.nn import CrossEntropyLoss, MSELoss
 from .configuration_roberta import RobertaConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_bert import BertEmbeddings, BertLayerNorm, BertModel, BertPreTrainedModel, gelu
-from .modeling_utils import create_position_ids_from_input_ids

 logger = logging.getLogger(__name__)

@@ -733,3 +732,17 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
             outputs = (total_loss,) + outputs

         return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
+
+
+def create_position_ids_from_input_ids(input_ids, padding_idx):
+    """Replace non-padding symbols with their position numbers. Position numbers begin at
+    padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
+    `utils.make_positions`.
+
+    :param torch.Tensor input_ids:
+    :return torch.Tensor:
+    """
+    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+    mask = input_ids.ne(padding_idx).int()
+    incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
+    return incremental_indices.long() + padding_idx
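For reference, a quick sketch of what the relocated helper computes (the token ids are illustrative; 1 is RoBERTa's padding id): non-padding tokens receive consecutive positions starting at padding_idx + 1, while padding tokens keep padding_idx.

```python
import torch
from transformers.modeling_roberta import create_position_ids_from_input_ids

input_ids = torch.tensor([[0, 31414, 232, 2, 1, 1]])  # illustrative ids; the trailing 1s are padding
print(create_position_ids_from_input_ids(input_ids, padding_idx=1))
# tensor([[2, 3, 4, 5, 1, 1]])
```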
src/transformers/modeling_utils.py

@@ -2090,20 +2090,6 @@ class SequenceSummary(nn.Module):
         return output


-def create_position_ids_from_input_ids(input_ids, padding_idx):
-    """ Replace non-padding symbols with their position numbers. Position numbers begin at
-    padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
-    `utils.make_positions`.
-
-    :param torch.Tensor x:
-    :return torch.Tensor:
-    """
-    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
-    mask = input_ids.ne(padding_idx).int()
-    incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
-    return incremental_indices.long() + padding_idx
-
-
 def prune_linear_layer(layer, index, dim=0):
     """ Prune a linear layer (a model parameters) to keep only entries in index.
         Return the pruned layer as a new layer with requires_grad=True.
tests/test_modeling_roberta.py

@@ -34,9 +34,8 @@ if is_torch_available():
         RobertaForSequenceClassification,
         RobertaForTokenClassification,
     )
-    from transformers.modeling_roberta import RobertaEmbeddings
+    from transformers.modeling_roberta import RobertaEmbeddings, create_position_ids_from_input_ids
     from transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST
-    from transformers.modeling_utils import create_position_ids_from_input_ids


 class RobertaModelTester: