"...resnet50_tensorflow.git" did not exist on "b0e3311267886837ba8aa73510b1a3eaa4bebb3a"
Unverified Commit 917dbb15 authored by Sergey Mkrtchyan's avatar Sergey Mkrtchyan Committed by GitHub
Browse files

Fix DPRReaderTokenizer's attention_mask (#9663)

* Fix the attention_mask in DPRReaderTokenizer

* Add an integration test for DPRReader inference

* Run make style
parent 12c1b5b8
......@@ -251,7 +251,9 @@ class CustomDPRReaderTokenizerMixin:
]
}
if return_attention_mask is not False:
attention_mask = [input_ids != self.pad_token_id for input_ids in encoded_inputs["input_ids"]]
attention_mask = []
for input_ids in encoded_inputs["input_ids"]:
attention_mask.append([int(input_id != self.pad_token_id) for input_id in input_ids])
encoded_inputs["attention_mask"] = attention_mask
return self.pad(encoded_inputs, padding=padding, max_length=max_length, return_tensors=return_tensors)
......
......@@ -252,7 +252,9 @@ class CustomDPRReaderTokenizerMixin:
]
}
if return_attention_mask is not False:
attention_mask = [input_ids != self.pad_token_id for input_ids in encoded_inputs["input_ids"]]
attention_mask = []
for input_ids in encoded_inputs["input_ids"]:
attention_mask.append([int(input_id != self.pad_token_id) for input_id in input_ids])
encoded_inputs["attention_mask"] = attention_mask
return self.pad(encoded_inputs, padding=padding, max_length=max_length, return_tensors=return_tensors)
......
......@@ -26,7 +26,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention
if is_torch_available():
import torch
from transformers import DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
from transformers import DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader, DPRReaderTokenizer
from transformers.models.dpr.modeling_dpr import (
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
......@@ -260,3 +260,33 @@ class DPRModelIntegrationTest(unittest.TestCase):
device=torch_device,
)
self.assertTrue(torch.allclose(output[:, :10], expected_slice, atol=1e-4))
@slow
def test_reader_inference(self):
tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")
model = DPRReader.from_pretrained("facebook/dpr-reader-single-nq-base")
encoded_inputs = tokenizer(
questions="What is love ?",
titles="Haddaway",
texts="What Is Love is a song recorded by the artist Haddaway",
padding=True,
return_tensors="pt",
)
outputs = model(**encoded_inputs)
# compare the actual values for a slice.
expected_start_logits = torch.tensor(
[[-10.3005, -10.7765, -11.4872, -11.6841, -11.9312, -10.3002, -9.8544, -11.7378, -12.0821, -10.2975]],
dtype=torch.float,
device=torch_device,
)
expected_end_logits = torch.tensor(
[[-11.0684, -11.7041, -11.5397, -10.3465, -10.8791, -6.8443, -11.9959, -11.0364, -10.0096, -6.8405]],
dtype=torch.float,
device=torch_device,
)
self.assertTrue(torch.allclose(outputs.start_logits[:, :10], expected_start_logits, atol=1e-4))
self.assertTrue(torch.allclose(outputs.end_logits[:, :10], expected_end_logits, atol=1e-4))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment