Unverified Commit 0804d077 authored by Patrick von Platen, committed by GitHub

correct attention mask (#7373)

parent a8cbc426
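
Summary of the change: RAG's generation paths encoded the input questions without passing the tokenizer's attention mask, so in padded batches the pad tokens contributed to the pooled question embeddings and could skew retrieval. The diff below (the RAG evaluation script, the two `generate()` implementations, and the RAG tests) threads `attention_mask` from the tokenizer output through to the question encoder. A minimal sketch of the underlying issue, assuming the standard DPR question-encoder checkpoint (the checkpoint name is not part of this diff):

```python
import torch
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

name = "facebook/dpr-question_encoder-single-nq-base"  # assumed checkpoint
tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(name)
encoder = DPRQuestionEncoder.from_pretrained(name).eval()

# "what" is much shorter than the first question, so padding=True pads its row.
inputs = tokenizer(
    ["who sings does he love me with reba", "what"],
    padding=True,
    return_tensors="pt",
)

with torch.no_grad():
    # Buggy pattern: pad tokens are attended to and shift the pooled embedding.
    emb_unmasked = encoder(inputs.input_ids)[0]
    # Fixed pattern: pad tokens are masked out of the encoding.
    emb_masked = encoder(inputs.input_ids, attention_mask=inputs.attention_mask)[0]

# The padded row differs; retrieval built on it can return different documents.
print((emb_unmasked[1] - emb_masked[1]).abs().max())
```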
@@ -115,11 +115,15 @@ def evaluate_batch_retrieval(args, rag_model, questions):

 def evaluate_batch_e2e(args, rag_model, questions):
     with torch.no_grad():
-        input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
+        inputs_dict = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
             questions, return_tensors="pt", padding=True, truncation=True
-        )["input_ids"].to(args.device)
+        )
+        input_ids = inputs_dict.input_ids.to(args.device)
+        attention_mask = inputs_dict.attention_mask.to(args.device)
+
         outputs = rag_model.generate(  # rag_model overwrites generate
             input_ids,
+            attention_mask=attention_mask,
             num_beams=args.num_beams,
             min_length=args.min_length,
             max_length=args.max_length,
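
One detail the corrected script relies on: `batch_encode_plus` returns a `BatchEncoding`, which supports attribute access to its tensors, so `inputs_dict.input_ids` and `inputs_dict.attention_mask` both work. A condensed sketch of the fixed pattern as a standalone helper (`encode_questions` is hypothetical, not a function in the repo):

```python
def encode_questions(tokenizer, questions, device):
    # Keep the whole BatchEncoding instead of indexing out input_ids,
    # so the attention mask is still available to pass to generate().
    inputs_dict = tokenizer.batch_encode_plus(
        questions, return_tensors="pt", padding=True, truncation=True
    )
    return inputs_dict.input_ids.to(device), inputs_dict.attention_mask.to(device)
```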
@@ -814,7 +814,8 @@ class RagSequenceForGeneration(RagPreTrainedModel):
     @torch.no_grad()
     def generate(
         self,
-        input_ids,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
         context_input_ids=None,
         do_deduplication=None,  # defaults to True
         num_return_sequences=None,  # defaults to 1
@@ -859,7 +860,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         # TODO(patrick) - clean up generate here
         if self.retriever is not None and context_input_ids is None:
-            question_hidden_states = self.question_encoder(input_ids)[0]
+            question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
             context_input_ids = self.retriever(
                 input_ids,
                 question_hidden_states.cpu().detach().to(torch.float32).numpy(),
@@ -1180,6 +1181,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
     def generate(
         self,
         input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
         context_input_ids=None,
         context_attention_mask=None,
         doc_scores=None,
@@ -1293,7 +1295,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
         # retrieve docs
         if self.retriever is not None and context_input_ids is None:
-            question_hidden_states = self.question_encoder(input_ids)[0]
+            question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
             out = self.retriever(
                 input_ids,
                 question_hidden_states.cpu().detach().to(torch.float32).numpy(),
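
With the new signatures in place, callers can hand the encoder attention mask straight to `generate()` on either RAG variant, and it is forwarded to the question encoder before retrieval. A hedged usage sketch; the checkpoint name, `index_name="exact"`, and `use_dummy_dataset=True` are the usual RAG demo setup and are assumptions, not part of this diff:

```python
from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer

name = "facebook/rag-token-nq"  # assumed checkpoint, not taken from this diff
tokenizer = RagTokenizer.from_pretrained(name)
retriever = RagRetriever.from_pretrained(name, index_name="exact", use_dummy_dataset=True)
model = RagTokenForGeneration.from_pretrained(name, retriever=retriever)

# Encode with the question-encoder tokenizer; padding is what makes the mask matter.
inputs = tokenizer.question_encoder(
    ["who sings does he love me with reba", "what"],
    padding=True,
    return_tensors="pt",
)

# attention_mask is now threaded through question encoding and retrieval.
output_ids = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    num_beams=4,
)
print(tokenizer.generator.batch_decode(output_ids, skip_special_tokens=True))
```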
@@ -416,7 +416,6 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
         t5_config_and_inputs = generator_tester.prepare_config_and_inputs()
         (question_encoder_config, input_ids, _, input_mask, _, _, _) = dpr_config_and_inputs
-        # import ipdb; ipdb.set_trace()
         (generator_config, _, decoder_input_ids, _, decoder_attention_mask, _) = t5_config_and_inputs
         config = RagConfig.from_question_encoder_generator_configs(
             question_encoder_config,
@@ -620,18 +619,21 @@ class RagModelIntegrationTests(unittest.TestCase):
         questions = [
             "who sings does he love me with reba",
             "how many pages is invisible man by ralph ellison",
+            "what",
         ]

-        input_ids = rag_question_encoder_tokenizer.batch_encode_plus(
+        input_dict = rag_question_encoder_tokenizer.batch_encode_plus(
             questions,
             return_tensors="pt",
             padding=True,
             truncation=True,
-        ).input_ids
-        input_ids = input_ids.to(torch_device)
+        )
+        input_ids = input_dict.input_ids.to(torch_device)
+        attention_mask = input_dict.attention_mask.to(torch_device)

         output_ids = rag_token.generate(
             input_ids,
+            attention_mask=attention_mask,
             decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
             num_beams=4,
             num_return_sequences=1,
@@ -641,13 +643,16 @@ class RagModelIntegrationTests(unittest.TestCase):
         # sequence generate test
         output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
         output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
+        output_text_3 = rag_decoder_tokenizer.decode(output_ids[2], skip_special_tokens=True)

         # Expected outputs as given by model at integration time.
         EXPECTED_OUTPUT_TEXT_1 = '"People Need Love" is the'
         EXPECTED_OUTPUT_TEXT_2 = '"How many pages is invisible man'
+        EXPECTED_OUTPUT_TEXT_3 = "Otis the Aardvark"

         self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
         self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
+        self.assertEqual(output_text_3, EXPECTED_OUTPUT_TEXT_3)

     @slow
     def test_rag_sequence_generate_batch(self):
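
The test changes mirror the fix: the added question "what" is much shorter than the others, so `padding=True` produces pad tokens for its row, and its expected answer is only stable once those pads are masked out during question encoding. A quick illustration of the padded batch; the tokenizer checkpoint name is an assumption for illustration, not taken from the diff:

```python
from transformers import DPRQuestionEncoderTokenizer

tok = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
batch = tok.batch_encode_plus(
    [
        "who sings does he love me with reba",
        "how many pages is invisible man by ralph ellison",
        "what",
    ],
    return_tensors="pt",
    padding=True,
    truncation=True,
)
print(batch.attention_mask)  # the "what" row is mostly zeros (pad positions)
```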
@@ -669,18 +674,22 @@ class RagModelIntegrationTests(unittest.TestCase):
         questions = [
             "who sings does he love me with reba",
             "how many pages is invisible man by ralph ellison",
+            "what",
         ]

-        input_ids = rag_question_encoder_tokenizer.batch_encode_plus(
+        input_dict = rag_question_encoder_tokenizer.batch_encode_plus(
             questions,
             return_tensors="pt",
             padding=True,
             truncation=True,
-        ).input_ids
-        input_ids = input_ids.to(torch_device)
+        )
+
+        input_ids = input_dict.input_ids.to(torch_device)
+        attention_mask = input_dict.attention_mask.to(torch_device)

         output_ids = rag_sequence.generate(
             input_ids,
+            attention_mask=attention_mask,
             decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id,
             num_beams=4,
             num_return_sequences=1,
@@ -690,13 +699,16 @@ class RagModelIntegrationTests(unittest.TestCase):
         # sequence generate test
         output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
         output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
+        output_text_3 = rag_decoder_tokenizer.decode(output_ids[2], skip_special_tokens=True)

         # Expected outputs as given by model at integration time.
         EXPECTED_OUTPUT_TEXT_1 = '"I Know Him So Well"'
         EXPECTED_OUTPUT_TEXT_2 = '"Howl" chronicles the'
+        EXPECTED_OUTPUT_TEXT_3 = "Otis the Aardvark"

         self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
         self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
+        self.assertEqual(output_text_3, EXPECTED_OUTPUT_TEXT_3)

     @slow
     def test_rag_sequence_generate_beam(self):