Unverified commit cf1c88e0 authored by Quentin Lhoest, committed by GitHub

[RAG] Fix retrieval offset in RAG's HfIndex and better integration tests (#7372)



* Fix retrieval offset in RAG's HfIndex

* update slow tests

* style

* fix new test

* style

* add better tests
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 571c7a11
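For readers skimming the diff below: the core change swaps `Dataset.get_nearest_examples_batch` (whose results were turned into doc ids via the dataset's "id" column and later reused as row positions, which drifts when ids and positions are offset) for `Dataset.search_batch`, which returns the row positions from the FAISS search directly. A minimal, self-contained sketch of the new lookup path, using a toy in-memory dataset instead of the real wiki_dpr index (the column names and sizes here are illustrative only):

```python
# Toy illustration of the position-based lookup adopted in this commit; this is
# not the real RAG dataset, just a 3-row stand-in with a FAISS index on it.
import numpy as np
from datasets import Dataset

dset = Dataset.from_dict(
    {
        "id": ["1", "2", "3"],  # string ids that do NOT match row positions 0..2
        "text": ["foo", "bar", "baz"],
        "embeddings": [[float(i)] * 4 for i in range(3)],
    }
)
dset.add_faiss_index("embeddings")

queries = np.full((1, 4), 2.0, dtype="float32")  # nearest to row position 2 ("baz")

# New approach: take the positional indices returned by the search ...
_, ids = dset.search_batch("embeddings", queries, k=1)
# ... and index the dataset by position, dropping any -1 padding from FAISS.
docs = [dset[[i for i in indices if i >= 0]] for indices in ids]

print(ids)              # [[2]]
print(docs[0]["text"])  # ['baz'] -- the row at position 2, regardless of its "id" value
```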
@@ -153,4 +153,4 @@ class RagRetrieverTest(TestCase):
         self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
         self.assertEqual(doc_dicts[0]["id"][0], "1")  # max inner product is reached with second doc
         self.assertEqual(doc_dicts[1]["id"][0], "0")  # max inner product is reached with first doc
-        self.assertListEqual(list(doc_ids), [1, 0])
+        self.assertListEqual(doc_ids.tolist(), [[1], [0]])
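The updated assertion reflects that `doc_ids` is now a 2-D array of shape `(batch_size, n_docs)` rather than a flat list; with two queries and `n_docs=1`, `tolist()` produces one inner list per query. A quick NumPy-only illustration (values copied from the test, not recomputed):

```python
import numpy as np

# Two queries, one retrieved doc each -> shape (batch_size, n_docs) == (2, 1).
doc_ids = np.array([[1], [0]])

print(doc_ids.shape)     # (2, 1)
print(doc_ids.tolist())  # [[1], [0]] -- nested per query, hence the new expected value
```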
@@ -203,6 +203,7 @@ class HFIndex:
         dataset_name: str,
         dataset_split: str,
         index_name: str,
+        vector_size: int,
         index_path: Optional[str] = None,
         use_dummy_dataset=False,
     ):

@@ -210,6 +211,7 @@ class HFIndex:
         self.dataset_name = dataset_name
         self.dataset_split = dataset_split
         self.index_name = index_name
+        self.vector_size = vector_size
         self.index_path = index_path
         self.use_dummy_dataset = use_dummy_dataset
         self._index_initialize = False

@@ -218,6 +220,7 @@ class HFIndex:
         self.dataset = load_dataset(
             self.dataset_name, with_index=False, split=self.dataset_split, dummy=self.use_dummy_dataset
         )
+        self.dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)

     def is_initialized(self):
         return self._index_initialize
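`set_format("numpy", columns=["embeddings"], output_all_columns=True)` tells `datasets` to hand the embedding column back as NumPy arrays while keeping the remaining columns available as plain Python objects, which the vector math in `get_top_docs` relies on. A small sketch with a made-up two-row dataset (the column names mirror the real ones, the values do not):

```python
from datasets import Dataset

dset = Dataset.from_dict(
    {
        "title": ["Foo", "Bar"],
        "embeddings": [[0.0, 1.0], [1.0, 0.0]],
    }
)
dset.set_format("numpy", columns=["embeddings"], output_all_columns=True)

row = dset[0]
print(type(row["embeddings"]))  # <class 'numpy.ndarray'>, ready for np.vstack / np.array
print(row["title"])             # 'Foo' -- still returned because output_all_columns=True
```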
@@ -236,15 +239,19 @@ class HFIndex:
                 index_name=self.index_name,
                 dummy=self.use_dummy_dataset,
             )
+            self.dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)
         self._index_initialize = True

     def get_doc_dicts(self, doc_ids: np.ndarray) -> List[dict]:
         return [self.dataset[doc_ids[i].tolist()] for i in range(doc_ids.shape[0])]

     def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]:
-        _, docs = self.dataset.get_nearest_examples_batch("embeddings", question_hidden_states, n_docs)
-        ids = [[int(i) for i in doc["id"]] for doc in docs]
+        _, ids = self.dataset.search_batch("embeddings", question_hidden_states, n_docs)
+        docs = [self.dataset[[i for i in indices if i >= 0]] for indices in ids]
         vectors = [doc["embeddings"] for doc in docs]
+        for i in range(len(vectors)):
+            if len(vectors[i]) < n_docs:
+                vectors[i] = np.vstack([vectors[i], np.zeros((n_docs - len(vectors[i]), self.vector_size))])
         return np.array(ids), np.array(vectors)  # shapes (batch_size, n_docs) and (batch_size, n_docs, d)
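The new loop at the end of `get_top_docs` handles queries for which FAISS returns fewer than `n_docs` valid hits (missing slots come back as `-1` and are filtered out when indexing the dataset); the embedding matrix for that query is padded with zero vectors so `np.array(vectors)` keeps a fixed `(batch_size, n_docs, d)` shape. The padding step in isolation, with toy sizes:

```python
import numpy as np

n_docs, vector_size = 5, 4

# Pretend the search only produced 3 valid documents for this query.
retrieved = np.ones((3, vector_size))

padded = np.vstack([retrieved, np.zeros((n_docs - len(retrieved), vector_size))])
print(padded.shape)  # (5, 4): always n_docs rows, the last two are zero vectors
```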
@@ -274,7 +281,12 @@ class RagRetriever:
             )
             if config.index_name == "legacy"
             else HFIndex(
-                config.dataset, config.dataset_split, config.index_name, config.index_path, config.use_dummy_dataset
+                config.dataset,
+                config.dataset_split,
+                config.index_name,
+                config.retrieval_vector_size,
+                config.index_path,
+                config.use_dummy_dataset,
             )
         )
         self.generator_tokenizer = generator_tokenizer

@@ -384,8 +396,9 @@ class RagRetriever:
             )
             ids_batched.extend(ids)
             vectors_batched.extend(vectors)
-        return np.array(ids_batched), np.array(
-            vectors_batched
+        return (
+            np.array(ids_batched),
+            np.array(vectors_batched),
         )  # shapes (batch_size, n_docs) and (batch_size, n_docs, d)

     def retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, List[dict]]:
...
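The return statement in the last hunk is only restyled; the batched retrieval still stacks the per-chunk results into two arrays with the shapes noted in the trailing comment. A tiny sketch of that stacking, with invented sizes:

```python
import numpy as np

batch_size, n_docs, d = 2, 5, 4

# Per-chunk results collected by the retriever loop (here: one chunk per query).
ids_batched = [np.arange(n_docs) for _ in range(batch_size)]
vectors_batched = [np.ones((n_docs, d)) for _ in range(batch_size)]

ids, vectors = np.array(ids_batched), np.array(vectors_batched)
print(ids.shape, vectors.shape)  # (2, 5) (2, 5, 4)
```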
@@ -54,6 +54,7 @@ if is_torch_available() and is_datasets_available() and is_faiss_available():
         RagRetriever,
         RagSequenceForGeneration,
         RagTokenForGeneration,
+        RagTokenizer,
     )
     from transformers.modeling_outputs import BaseModelOutput

@@ -519,7 +520,7 @@ class RagModelIntegrationTests(unittest.TestCase):
         expected_doc_scores = torch.tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]]).to(torch_device)
         _assert_tensors_equal(expected_doc_scores, output.doc_scores, atol=TOLERANCE)

-        expected_loss = torch.tensor([38.7446]).to(torch_device)
+        expected_loss = torch.tensor([36.7368]).to(torch_device)
         _assert_tensors_equal(expected_loss, output.loss, atol=TOLERANCE)

     @slow

@@ -558,7 +559,7 @@ class RagModelIntegrationTests(unittest.TestCase):
         expected_doc_scores = torch.tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]]).to(torch_device)
         _assert_tensors_equal(expected_doc_scores, output.doc_scores, atol=TOLERANCE)

-        expected_loss = torch.tensor([38.7045]).to(torch_device)
+        expected_loss = torch.tensor([36.3557]).to(torch_device)
         _assert_tensors_equal(expected_loss, output.loss, atol=TOLERANCE)

     @slow

@@ -594,14 +595,14 @@ class RagModelIntegrationTests(unittest.TestCase):
         output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)

         # Expected outputs as given by model at integration time.
-        EXPECTED_OUTPUT_TEXT_1 = "The songwriting credits are credited to ABBA"
-        EXPECTED_OUTPUT_TEXT_2 = 'The songwriting credits are credited to "B'
+        EXPECTED_OUTPUT_TEXT_1 = "\"She's My Kind of Girl"
+        EXPECTED_OUTPUT_TEXT_2 = "\"She's My Kind of Love"

         self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
         self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)

     @slow
-    def test_rag_token_generate_batch(self):
+    def test_rag_sequence_generate_beam(self):
         rag_config = self.get_rag_config()
         rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
         rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
@@ -613,72 +614,64 @@ class RagModelIntegrationTests(unittest.TestCase):
             generator_tokenizer=rag_decoder_tokenizer,
         )

-        rag_token = self.token_model
+        rag_token = self.sequence_model
         rag_token.set_retriever(rag_retriever)

-        questions = [
-            "who sings does he love me with reba",
-            "how many pages is invisible man by ralph ellison",
-            "what",
-        ]
-
-        input_dict = rag_question_encoder_tokenizer.batch_encode_plus(
-            questions,
-            return_tensors="pt",
-            padding=True,
-            truncation=True,
-        )
-
-        input_ids = input_dict.input_ids.to(torch_device)
-        attention_mask = input_dict.attention_mask.to(torch_device)
+        input_ids = rag_question_encoder_tokenizer(
+            "who sings does he love me with reba", return_tensors="pt"
+        ).input_ids
+
+        input_ids = input_ids.to(torch_device)

         output_ids = rag_token.generate(
             input_ids,
-            attention_mask=attention_mask,
             decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
-            num_beams=4,
-            num_return_sequences=1,
-            max_length=10,
+            num_beams=2,
+            num_return_sequences=2,
         )

         # sequence generate test
         output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
         output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
-        output_text_3 = rag_decoder_tokenizer.decode(output_ids[2], skip_special_tokens=True)

         # Expected outputs as given by model at integration time.
-        EXPECTED_OUTPUT_TEXT_1 = '"People Need Love" is the'
-        EXPECTED_OUTPUT_TEXT_2 = '"How many pages is invisible man'
-        EXPECTED_OUTPUT_TEXT_3 = "Otis the Aardvark"
+        EXPECTED_OUTPUT_TEXT_1 = """\"She's My Kind of Girl\" was released through Epic Records in Japan in March 1972, giving the duo a Top 10 hit. Two more singles were released in Japan, \"En Carousel\" and \"Love Has Its Ways\" Ulvaeus and Andersson persevered with their songwriting and experimented with new sounds and vocal arrangements."""
+        EXPECTED_OUTPUT_TEXT_2 = """In September 2018, Björn Ulvaeus revealed that the two new songs, \"I Still Have Faith In You\" and \"Don't Shut Me Down\", would be released no earlier than March 2019. The two new tracks will feature in a TV special set to air later in the year."""

         self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
         self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
-        self.assertEqual(output_text_3, EXPECTED_OUTPUT_TEXT_3)
+
+    @property
+    def test_data_questions(self):
+        return [
+            "who got the first nobel prize in physics",
+            "when is the next deadpool movie being released",
+            "which mode is used for short wave broadcast service",
+            "who is the owner of reading football club",
+            "when is the next scandal episode coming out",
+            "when is the last time the philadelphia won the superbowl",
+            "what is the most current adobe flash player version",
+            "how many episodes are there in dragon ball z",
+            "what is the first step in the evolution of the eye",
+            "where is gall bladder situated in human body",
+            "what is the main mineral in lithium batteries",
+            "who is the president of usa right now",
+            "where do the greasers live in the outsiders",
+            "panda is a national animal of which country",
+            "what is the name of manchester united stadium",
+        ]

     @slow
     def test_rag_sequence_generate_batch(self):
-        # IMPORTAN: This test fails on GPU, but is fine on CPU -> beam search is very sensible
-        rag_config = self.get_rag_config()
-        rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
-        rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
-            "facebook/dpr-question_encoder-single-nq-base"
-        )
-        rag_retriever = RagRetriever(
-            rag_config,
-            question_encoder_tokenizer=rag_question_encoder_tokenizer,
-            generator_tokenizer=rag_decoder_tokenizer,
-        )
-
-        rag_sequence = self.sequence_model
-        rag_sequence.set_retriever(rag_retriever)
-
-        questions = [
-            "who sings does he love me with reba",
-            "how many pages is invisible man by ralph ellison",
-            "what",
-        ]
-
-        input_dict = rag_question_encoder_tokenizer.batch_encode_plus(
-            questions,
+        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
+        retriever = RagRetriever.from_pretrained(
+            "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
+        )
+        rag_sequence = RagTokenForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
+            torch_device
+        )
+
+        input_dict = tokenizer(
+            self.test_data_questions,
             return_tensors="pt",
             padding=True,
             truncation=True,
@@ -690,64 +683,72 @@ class RagModelIntegrationTests(unittest.TestCase):
         output_ids = rag_sequence.generate(
             input_ids,
             attention_mask=attention_mask,
-            decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id,
-            num_beams=4,
-            num_return_sequences=1,
-            max_length=10,
         )

-        # sequence generate test
-        output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
-        output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
-        output_text_3 = rag_decoder_tokenizer.decode(output_ids[2], skip_special_tokens=True)
-
-        # Expected outputs as given by model at integration time.
-        EXPECTED_OUTPUT_TEXT_1 = '"I Know Him So Well"'
-        EXPECTED_OUTPUT_TEXT_2 = '"Howl" chronicles the'
-        EXPECTED_OUTPUT_TEXT_3 = "Otis the Aardvark"
-
-        self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
-        self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
-        self.assertEqual(output_text_3, EXPECTED_OUTPUT_TEXT_3)
+        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+
+        EXPECTED_OUTPUTS = [
+            " albert einstein",
+            " june 22, 2018",
+            " amplitude modulation",
+            " tim besley ( chairman )",
+            " june 20, 2018",
+            " 1980",
+            " 7.0",
+            " 8",
+            " reticular formation",
+            " walls of the abdomen",
+            " spodumene",
+            " obama",
+            " grainger's compound",
+            " japan",
+            " old trafford stadium",
+        ]
+        self.assertListEqual(outputs, EXPECTED_OUTPUTS)

     @slow
-    def test_rag_sequence_generate_beam(self):
-        rag_config = self.get_rag_config()
-        rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
-        rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
-            "facebook/dpr-question_encoder-single-nq-base"
-        )
-        rag_retriever = RagRetriever(
-            rag_config,
-            question_encoder_tokenizer=rag_question_encoder_tokenizer,
-            generator_tokenizer=rag_decoder_tokenizer,
-        )
-
-        rag_token = self.sequence_model
-        rag_token.set_retriever(rag_retriever)
-
-        input_ids = rag_question_encoder_tokenizer(
-            "who sings does he love me with reba", return_tensors="pt"
-        ).input_ids
-
-        input_ids = input_ids.to(torch_device)
+    def test_rag_token_generate_batch(self):
+        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
+        retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
+        rag_token = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever).to(
+            torch_device
+        )
+
+        input_dict = tokenizer(
+            self.test_data_questions,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+        )
+
+        input_ids = input_dict.input_ids.to(torch_device)
+        attention_mask = input_dict.attention_mask.to(torch_device)

         output_ids = rag_token.generate(
             input_ids,
-            decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
-            num_beams=2,
-            num_return_sequences=2,
+            attention_mask=attention_mask,
         )

-        # sequence generate test
-        output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
-        output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)
-
-        # Expected outputs as given by model at integration time.
-        EXPECTED_OUTPUT_TEXT_1 = """ ABBA / small label like Playboy Records did not have the distribution resources to meet the demand for the single from retailers and radio programmers. The foursome decided to record their first album together in late 1972, and sessions began on 26 September 1972. The women shared lead vocals on "Nina, Pretty Ballerina" that day."""
-        EXPECTED_OUTPUT_TEXT_2 = """ ABBA / small label like Playboy Records did not have the distribution resources to meet the demand for the single from retailers and radio programmers. The foursome decided to record their first album together in late 1972, and sessions began on 26 September 1972. The women shared lead vocals on "Nina, Pretty Ballerina" (a top ten hit in Austria)"""
-
-        self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
-        self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
+        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+
+        EXPECTED_OUTPUTS = [
+            " albert einstein",
+            " september 22, 2017",
+            " amplitude modulation",
+            " stefan persson",
+            " april 20, 2018",
+            " the 1970s",
+            " 7.1. 2",
+            " 13",
+            " step by step",
+            " stomach",
+            " spodumene",
+            " obama",
+            " northern new jersey",
+            " india",
+            " united stadium",
+        ]
+        self.assertListEqual(outputs, EXPECTED_OUTPUTS)

     @require_torch
...
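The rewritten integration tests above no longer assemble the retriever from a hand-built config; they pull the tokenizer, retriever, and model straight from the published checkpoints. Roughly the same flow, trimmed to a single question and using the same dummy-index options as the tests (this is a sketch of the test pattern, not an official recipe; the dummy dataset keeps the download small but makes the answers less reliable):

```python
import torch
from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

inputs = tokenizer(
    ["who got the first nobel prize in physics"],
    return_tensors="pt",
    padding=True,
    truncation=True,
)
with torch.no_grad():
    output_ids = model.generate(inputs.input_ids, attention_mask=inputs.attention_mask)

print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```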
@@ -166,7 +166,7 @@ class RagRetrieverTest(TestCase):
         self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
         self.assertEqual(doc_dicts[0]["id"][0], "1")  # max inner product is reached with second doc
         self.assertEqual(doc_dicts[1]["id"][0], "0")  # max inner product is reached with first doc
-        self.assertListEqual(list(doc_ids), [1, 0])
+        self.assertListEqual(doc_ids.tolist(), [[1], [0]])

     def test_legacy_index_retriever_retrieve(self):
         n_docs = 1

@@ -181,7 +181,7 @@ class RagRetrieverTest(TestCase):
         self.assertEqual(len(doc_dicts[0]["text"]), n_docs)
         self.assertEqual(doc_dicts[0]["text"][0], "bar")  # max inner product is reached with second doc
         self.assertEqual(doc_dicts[1]["text"][0], "foo")  # max inner product is reached with first doc
-        self.assertListEqual(list(doc_ids), [1, 0])
+        self.assertListEqual(doc_ids.tolist(), [[1], [0]])

     @require_torch
     def test_hf_index_retriever_call(self):
...