Unverified Commit f15f0871 authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #1764 from DomHudson/bug-fix-1761

Bug-fix: Roberta Embeddings Not Masked
parents fae4d1c2 3e52915f
......@@ -51,24 +51,44 @@ class RobertaEmbeddings(BertEmbeddings):
padding_idx=self.padding_idx)
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
device = input_ids.device if input_ids is not None else inputs_embeds.device
if position_ids is None:
# Position numbers begin at padding_idx+1. Padding symbols are ignored.
# cf. fairseq's `utils.make_positions`
position_ids = torch.arange(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(input_shape)
if input_ids is not None:
# Create the position ids from the input token ids. Any padded tokens remain padded.
position_ids = self.create_position_ids_from_input_ids(input_ids).to(input_ids.device)
else:
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
return super(RobertaEmbeddings, self).forward(input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds)
def create_position_ids_from_input_ids(self, x):
""" Replace non-padding symbols with their position numbers. Position numbers begin at
padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
`utils.make_positions`.
:param torch.Tensor x:
:return torch.Tensor:
"""
mask = x.ne(self.padding_idx).long()
incremental_indicies = torch.cumsum(mask, dim=1) * mask
return incremental_indicies + self.padding_idx
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
""" We are provided embeddings directly. We cannot infer which are padded so just generate
sequential position ids.
:param torch.Tensor inputs_embeds:
:return torch.Tensor:
"""
input_shape = inputs_embeds.size()[:-1]
sequence_length = input_shape[1]
position_ids = torch.arange(self.padding_idx+1, sequence_length+self.padding_idx+1, dtype=torch.long,
device=inputs_embeds.device)
return position_ids.unsqueeze(0).expand(input_shape)
ROBERTA_START_DOCSTRING = r""" The RoBERTa model was proposed in
`RoBERTa: A Robustly Optimized BERT Pretraining Approach`_
......
......@@ -24,6 +24,7 @@ if is_torch_available():
import torch
from transformers import (RobertaConfig, RobertaModel, RobertaForMaskedLM,
RobertaForSequenceClassification, RobertaForTokenClassification)
from transformers.modeling_roberta import RobertaEmbeddings
from transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_common_test import (CommonTestCases, ids_tensor)
......@@ -202,6 +203,57 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
model = RobertaModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
def test_create_position_ids_respects_padding_index(self):
""" Ensure that the default position ids only assign a sequential . This is a regression
test for https://github.com/huggingface/transformers/issues/1761
The position ids should be masked with the embedding object's padding index. Therefore, the
first available non-padding position index is RobertaEmbeddings.padding_idx + 1
"""
config = self.model_tester.prepare_config_and_inputs()[0]
model = RobertaEmbeddings(config=config)
input_ids = torch.as_tensor([[12, 31, 13, model.padding_idx]])
expected_positions = torch.as_tensor([[
0 + model.padding_idx + 1,
1 + model.padding_idx + 1,
2 + model.padding_idx + 1,
model.padding_idx
]])
position_ids = model.create_position_ids_from_input_ids(input_ids)
self.assertEqual(
position_ids.shape,
expected_positions.shape
)
self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))
def test_create_position_ids_from_inputs_embeds(self):
""" Ensure that the default position ids only assign a sequential . This is a regression
test for https://github.com/huggingface/transformers/issues/1761
The position ids should be masked with the embedding object's padding index. Therefore, the
first available non-padding position index is RobertaEmbeddings.padding_idx + 1
"""
config = self.model_tester.prepare_config_and_inputs()[0]
embeddings = RobertaEmbeddings(config=config)
inputs_embeds = torch.Tensor(2, 4, 30)
expected_single_positions = [
0 + embeddings.padding_idx + 1,
1 + embeddings.padding_idx + 1,
2 + embeddings.padding_idx + 1,
3 + embeddings.padding_idx + 1,
]
expected_positions = torch.as_tensor([expected_single_positions, expected_single_positions])
position_ids = embeddings.create_position_ids_from_inputs_embeds(inputs_embeds)
self.assertEqual(
position_ids.shape,
expected_positions.shape
)
self.assertTrue(
torch.all(torch.eq(position_ids, expected_positions))
)
class RobertaModelIntegrationTest(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment