Unverified Commit 2c8ecdf8 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

fix rag retriever save pretrained (#7399)

parent 1a14687e
...@@ -312,8 +312,8 @@ class RagRetriever: ...@@ -312,8 +312,8 @@ class RagRetriever:
def save_pretrained(self, save_directory): def save_pretrained(self, save_directory):
self.config.save_pretrained(save_directory) self.config.save_pretrained(save_directory)
rag_tokenizer = RagTokenizer( rag_tokenizer = RagTokenizer(
question_encoder_tokenizer=self.question_encoder_tokenizer, question_encoder=self.question_encoder_tokenizer,
generator_tokenizer=self.generator_tokenizer, generator=self.generator_tokenizer,
) )
rag_tokenizer.save_pretrained(save_directory) rag_tokenizer.save_pretrained(save_directory)
......
...@@ -168,6 +168,11 @@ class RagRetrieverTest(TestCase): ...@@ -168,6 +168,11 @@ class RagRetrieverTest(TestCase):
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
self.assertListEqual(doc_ids.tolist(), [[1], [0]]) self.assertListEqual(doc_ids.tolist(), [[1], [0]])
def test_save_and_from_pretrained(self):
retriever = self.get_dummy_hf_index_retriever()
with tempfile.TemporaryDirectory() as tmp_dirname:
retriever.save_pretrained(tmp_dirname)
def test_legacy_index_retriever_retrieve(self): def test_legacy_index_retriever_retrieve(self):
n_docs = 1 n_docs = 1
retriever = self.get_dummy_legacy_index_retriever() retriever = self.get_dummy_legacy_index_retriever()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment