Unverified Commit 0804d077 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

correct attention mask (#7373)

parent a8cbc426
......@@ -115,11 +115,15 @@ def evaluate_batch_retrieval(args, rag_model, questions):
def evaluate_batch_e2e(args, rag_model, questions):
with torch.no_grad():
input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
inputs_dict = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
questions, return_tensors="pt", padding=True, truncation=True
)["input_ids"].to(args.device)
)
input_ids = inputs_dict.input_ids.to(args.device)
attention_mask = inputs_dict.attention_mask.to(args.device)
outputs = rag_model.generate( # rag_model overwrites generate
input_ids,
attention_mask=attention_mask,
num_beams=args.num_beams,
min_length=args.min_length,
max_length=args.max_length,
......
......@@ -814,7 +814,8 @@ class RagSequenceForGeneration(RagPreTrainedModel):
@torch.no_grad()
def generate(
self,
input_ids,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
context_input_ids=None,
do_deduplication=None, # defaults to True
num_return_sequences=None, # defaults to 1
......@@ -859,7 +860,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
# TODO(patrick) - clean up generate here
if self.retriever is not None and context_input_ids is None:
question_hidden_states = self.question_encoder(input_ids)[0]
question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
context_input_ids = self.retriever(
input_ids,
question_hidden_states.cpu().detach().to(torch.float32).numpy(),
......@@ -1180,6 +1181,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
def generate(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
context_input_ids=None,
context_attention_mask=None,
doc_scores=None,
......@@ -1293,7 +1295,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
# retrieve docs
if self.retriever is not None and context_input_ids is None:
question_hidden_states = self.question_encoder(input_ids)[0]
question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
out = self.retriever(
input_ids,
question_hidden_states.cpu().detach().to(torch.float32).numpy(),
......
......@@ -416,7 +416,6 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
t5_config_and_inputs = generator_tester.prepare_config_and_inputs()
(question_encoder_config, input_ids, _, input_mask, _, _, _) = dpr_config_and_inputs
# import ipdb; ipdb.set_trace()
(generator_config, _, decoder_input_ids, _, decoder_attention_mask, _) = t5_config_and_inputs
config = RagConfig.from_question_encoder_generator_configs(
question_encoder_config,
......@@ -620,18 +619,21 @@ class RagModelIntegrationTests(unittest.TestCase):
questions = [
"who sings does he love me with reba",
"how many pages is invisible man by ralph ellison",
"what",
]
input_ids = rag_question_encoder_tokenizer.batch_encode_plus(
input_dict = rag_question_encoder_tokenizer.batch_encode_plus(
questions,
return_tensors="pt",
padding=True,
truncation=True,
).input_ids
)
input_ids = input_ids.to(torch_device)
input_ids = input_dict.input_ids.to(torch_device)
attention_mask = input_dict.attention_mask.to(torch_device)
output_ids = rag_token.generate(
input_ids,
attention_mask=attention_mask,
decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
num_beams=4,
num_return_sequences=1,
......@@ -641,13 +643,16 @@ class RagModelIntegrationTests(unittest.TestCase):
# sequence generate test
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
output_text_3 = rag_decoder_tokenizer.decode(output_ids[2], skip_special_tokens=True)
# Expected outputs as given by model at integration time.
EXPECTED_OUTPUT_TEXT_1 = '"People Need Love" is the'
EXPECTED_OUTPUT_TEXT_2 = '"How many pages is invisible man'
EXPECTED_OUTPUT_TEXT_3 = "Otis the Aardvark"
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
self.assertEqual(output_text_3, EXPECTED_OUTPUT_TEXT_3)
@slow
def test_rag_sequence_generate_batch(self):
......@@ -669,18 +674,22 @@ class RagModelIntegrationTests(unittest.TestCase):
questions = [
"who sings does he love me with reba",
"how many pages is invisible man by ralph ellison",
"what",
]
input_ids = rag_question_encoder_tokenizer.batch_encode_plus(
input_dict = rag_question_encoder_tokenizer.batch_encode_plus(
questions,
return_tensors="pt",
padding=True,
truncation=True,
).input_ids
)
input_ids = input_ids.to(torch_device)
input_ids = input_dict.input_ids.to(torch_device)
attention_mask = input_dict.attention_mask.to(torch_device)
output_ids = rag_sequence.generate(
input_ids,
attention_mask=attention_mask,
decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id,
num_beams=4,
num_return_sequences=1,
......@@ -690,13 +699,16 @@ class RagModelIntegrationTests(unittest.TestCase):
# sequence generate test
output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
output_text_3 = rag_decoder_tokenizer.decode(output_ids[2], skip_special_tokens=True)
# Expected outputs as given by model at integration time.
EXPECTED_OUTPUT_TEXT_1 = '"I Know Him So Well"'
EXPECTED_OUTPUT_TEXT_2 = '"Howl" chronicles the'
EXPECTED_OUTPUT_TEXT_3 = "Otis the Aardvark"
self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
self.assertEqual(output_text_3, EXPECTED_OUTPUT_TEXT_3)
@slow
def test_rag_sequence_generate_beam(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment