"git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "d5256fe8d71a2a407e82cf75eb618c99fa63d524"
Unverified Commit 033f29c6 authored by Quentin Lhoest's avatar Quentin Lhoest Committed by GitHub
Browse files

Allow Custom Dataset in RAG Retriever (#7763)

* add CustomHFIndex

* typo in config

* update tests

* add custom dataset example

* clean script

* update test data

* minor in test

* docs

* docs

* style

* fix imports

* allow to pass the indexed dataset directly

* update tests

* use multiset DPR

* address thom and patrick's comments

* style

* update dpr tokenizer

* add output_dir flag in use_own_knowledge_dataset.py

* allow custom datasets in examples/rag/finetune.py

* add test for custom dataset in distributed rag retriever
parent a09fe140
...@@ -62,8 +62,7 @@ Rag specific outputs ...@@ -62,8 +62,7 @@ Rag specific outputs
.. autoclass:: transformers.modeling_rag.RetrievAugLMOutput .. autoclass:: transformers.modeling_rag.RetrievAugLMOutput
:members: :members:
RagRetriever
RAGRetriever
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.RagRetriever .. autoclass:: transformers.RagRetriever
......
...@@ -27,13 +27,18 @@ class RagPyTorchDistributedRetriever(RagRetriever): ...@@ -27,13 +27,18 @@ class RagPyTorchDistributedRetriever(RagRetriever):
It is used to decode the question and then use the generator_tokenizer. It is used to decode the question and then use the generator_tokenizer.
generator_tokenizer (:class:`~transformers.PretrainedTokenizer`): generator_tokenizer (:class:`~transformers.PretrainedTokenizer`):
The tokenizer used for the generator part of the RagModel. The tokenizer used for the generator part of the RagModel.
index (:class:`~transformers.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
If specified, use this index instead of the one built using the configuration
""" """
_init_retrieval = False _init_retrieval = False
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer): def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None):
super().__init__( super().__init__(
config, question_encoder_tokenizer=question_encoder_tokenizer, generator_tokenizer=generator_tokenizer config,
question_encoder_tokenizer=question_encoder_tokenizer,
generator_tokenizer=generator_tokenizer,
index=index,
) )
self.process_group = None self.process_group = None
......
...@@ -90,6 +90,11 @@ class GenerativeQAModule(BaseTransformer): ...@@ -90,6 +90,11 @@ class GenerativeQAModule(BaseTransformer):
config_class = RagConfig if self.is_rag_model else AutoConfig config_class = RagConfig if self.is_rag_model else AutoConfig
config = config_class.from_pretrained(hparams.model_name_or_path) config = config_class.from_pretrained(hparams.model_name_or_path)
# set retriever parameters
config.index_name = args.index_name or config.index_name
config.passages_path = args.passages_path or config.passages_path
config.index_path = args.index_path or config.index_path
# set extra_model_params for generator configs and load_model # set extra_model_params for generator configs and load_model
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout") extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout")
if self.is_rag_model: if self.is_rag_model:
...@@ -97,7 +102,7 @@ class GenerativeQAModule(BaseTransformer): ...@@ -97,7 +102,7 @@ class GenerativeQAModule(BaseTransformer):
config.generator.prefix = args.prefix config.generator.prefix = args.prefix
config.label_smoothing = hparams.label_smoothing config.label_smoothing = hparams.label_smoothing
hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator) hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path) retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever) model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
prefix = config.question_encoder.prefix prefix = config.question_encoder.prefix
else: else:
...@@ -405,6 +410,28 @@ class GenerativeQAModule(BaseTransformer): ...@@ -405,6 +410,28 @@ class GenerativeQAModule(BaseTransformer):
) )
return parser return parser
@staticmethod
def add_retriever_specific_args(parser):
    """Register the retriever-related CLI flags on ``parser`` and return it.

    All three flags default to ``None`` so that downstream code can write
    ``args.<flag> or config.<flag>`` and fall back to the value already stored
    in the model's RagConfig when the flag is not passed.
    """
    parser.add_argument(
        "--index_name",
        type=str,
        default=None,
        # fixed: "orignal" typo and a stray closing parenthesis in the original help text
        help="Name of the index to use: 'hf' for a canonical dataset from the datasets library (default), 'custom' for a local index, or 'legacy' for the original one",
    )
    parser.add_argument(
        "--passages_path",
        type=str,
        default=None,
        help="Path to the dataset of passages for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
    )
    parser.add_argument(
        "--index_path",
        type=str,
        default=None,
        help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
    )
    return parser
def main(args, model=None) -> GenerativeQAModule: def main(args, model=None) -> GenerativeQAModule:
Path(args.output_dir).mkdir(exist_ok=True) Path(args.output_dir).mkdir(exist_ok=True)
...@@ -465,6 +492,7 @@ if __name__ == "__main__": ...@@ -465,6 +492,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser) parser = pl.Trainer.add_argparse_args(parser)
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd()) parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
parser = GenerativeQAModule.add_retriever_specific_args(parser)
args = parser.parse_args() args = parser.parse_args()
......
Aaron Aaron Aaron ( or ; "Ahärôn") is a prophet, high priest, and the brother of Moses in the Abrahamic religions. Knowledge of Aaron, along with his brother Moses, comes exclusively from religious texts, such as the Bible and Quran. The Hebrew Bible relates that, unlike Moses, who grew up in the Egyptian royal court, Aaron and his elder sister Miriam remained with their kinsmen in the eastern border-land of Egypt (Goshen). When Moses first confronted the Egyptian king about the Israelites, Aaron served as his brother's spokesman ("prophet") to the Pharaoh. Part of the Law (Torah) that Moses received from God at Sinai granted Aaron the priesthood for himself and his male descendants, and he became the first High Priest of the Israelites. Aaron died before the Israelites crossed the North Jordan river and he was buried on Mount Hor (Numbers 33:39; Deuteronomy 10:6 says he died and was buried at Moserah). Aaron is also mentioned in the New Testament of the Bible. According to the Book of Exodus, Aaron first functioned as Moses' assistant. Because Moses complained that he could not speak well, God appointed Aaron as Moses' "prophet" (Exodus 4:10-17; 7:1). At the command of Moses, he let his rod turn into a snake. Then he stretched out his rod in order to bring on the first three plagues. After that, Moses tended to act and speak for himself. During the journey in the wilderness, Aaron was not always prominent or active. At the battle with Amalek, he was chosen with Hur to support the hand of Moses that held the "rod of God". When the revelation was given to Moses at biblical Mount Sinai, he headed the elders of Israel who accompanied Moses on the way to the summit.
"Pokémon" Pokémon , also known as in Japan, is a media franchise managed by The Pokémon Company, a Japanese consortium between Nintendo, Game Freak, and Creatures. The franchise copyright is shared by all three companies, but Nintendo is the sole owner of the trademark. The franchise was created by Satoshi Tajiri in 1995, and is centered on fictional creatures called "Pokémon", which humans, known as Pokémon Trainers, catch and train to battle each other for sport. The English slogan for the franchise is "Gotta Catch 'Em All". Works within the franchise are set in the Pokémon universe. The franchise began as "Pokémon Red" and "Green" (released outside of Japan as "Pokémon Red" and "Blue"), a pair of video games for the original Game Boy that were developed by Game Freak and published by Nintendo in February 1996. "Pokémon" has since gone on to become the highest-grossing media franchise of all time, with over in revenue up until March 2017. The original video game series is the second best-selling video game franchise (behind Nintendo's "Mario" franchise) with more than 300million copies sold and over 800million mobile downloads. In addition, the "Pokémon" franchise includes the world's top-selling toy brand, the top-selling trading card game with over 25.7billion cards sold, an anime television series that has become the most successful video game adaptation with over 20 seasons and 1,000 episodes in 124 countries, as well as an anime film series, a , books, manga comics, music, and merchandise. The franchise is also represented in other Nintendo media, such as the "Super Smash Bros." series. In November 2005, 4Kids Entertainment, which had managed the non-game related licensing of "Pokémon", announced that it had agreed not to renew the "Pokémon" representation agreement. The Pokémon Company International oversees all "Pokémon" licensing outside Asia.
\ No newline at end of file
...@@ -15,6 +15,7 @@ from transformers.configuration_bart import BartConfig ...@@ -15,6 +15,7 @@ from transformers.configuration_bart import BartConfig
from transformers.configuration_dpr import DPRConfig from transformers.configuration_dpr import DPRConfig
from transformers.configuration_rag import RagConfig from transformers.configuration_rag import RagConfig
from transformers.file_utils import is_datasets_available, is_faiss_available, is_psutil_available, is_torch_available from transformers.file_utils import is_datasets_available, is_faiss_available, is_psutil_available, is_torch_available
from transformers.retrieval_rag import CustomHFIndex
from transformers.tokenization_bart import BartTokenizer from transformers.tokenization_bart import BartTokenizer
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
...@@ -114,7 +115,7 @@ class RagRetrieverTest(TestCase): ...@@ -114,7 +115,7 @@ class RagRetrieverTest(TestCase):
def tearDown(self): def tearDown(self):
shutil.rmtree(self.tmpdirname) shutil.rmtree(self.tmpdirname)
def get_dummy_pytorch_distributed_retriever(self, init_retrieval, port=12345) -> RagPyTorchDistributedRetriever: def get_dummy_dataset(self):
dataset = Dataset.from_dict( dataset = Dataset.from_dict(
{ {
"id": ["0", "1"], "id": ["0", "1"],
...@@ -124,6 +125,12 @@ class RagRetrieverTest(TestCase): ...@@ -124,6 +125,12 @@ class RagRetrieverTest(TestCase):
} }
) )
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT) dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
return dataset
def get_dummy_pytorch_distributed_retriever(
self, init_retrieval: bool, port=12345
) -> RagPyTorchDistributedRetriever:
dataset = self.get_dummy_dataset()
config = RagConfig( config = RagConfig(
retrieval_vector_size=self.retrieval_vector_size, retrieval_vector_size=self.retrieval_vector_size,
question_encoder=DPRConfig().to_dict(), question_encoder=DPRConfig().to_dict(),
...@@ -140,6 +147,37 @@ class RagRetrieverTest(TestCase): ...@@ -140,6 +147,37 @@ class RagRetrieverTest(TestCase):
retriever.init_retrieval(port) retriever.init_retrieval(port)
return retriever return retriever
def get_dummy_custom_hf_index_retriever(self, init_retrieval: bool, from_disk: bool, port=12345):
    """Build a RagPyTorchDistributedRetriever backed by a custom (non-canonical) HF index.

    With ``from_disk=True`` this exercises the config-driven path: the dataset and the
    faiss index are saved separately under the test tmpdir and the retriever must load
    them from ``config.passages_path`` / ``config.index_path``. With ``from_disk=False``
    an in-memory ``CustomHFIndex`` is handed to the retriever directly.
    """
    dataset = self.get_dummy_dataset()
    config = RagConfig(
        retrieval_vector_size=self.retrieval_vector_size,
        question_encoder=DPRConfig().to_dict(),
        generator=BartConfig().to_dict(),
        index_name="custom",
    )
    if from_disk:
        config.passages_path = os.path.join(self.tmpdirname, "dataset")
        config.index_path = os.path.join(self.tmpdirname, "index.faiss")
        # Persist the index and the dataset separately, then drop the in-memory
        # copies so the retriever is forced to reload everything from disk.
        dataset.get_index("embeddings").save(os.path.join(self.tmpdirname, "index.faiss"))
        dataset.drop_index("embeddings")  # the faiss index is not serialized by save_to_disk
        dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
        del dataset
        retriever = RagPyTorchDistributedRetriever(
            config,
            question_encoder_tokenizer=self.get_dpr_tokenizer(),
            generator_tokenizer=self.get_bart_tokenizer(),
        )
    else:
        retriever = RagPyTorchDistributedRetriever(
            config,
            question_encoder_tokenizer=self.get_dpr_tokenizer(),
            generator_tokenizer=self.get_bart_tokenizer(),
            index=CustomHFIndex(config.retrieval_vector_size, dataset),
        )
    if init_retrieval:
        retriever.init_retrieval(port)
    return retriever
def test_pytorch_distributed_retriever_retrieve(self): def test_pytorch_distributed_retriever_retrieve(self):
n_docs = 1 n_docs = 1
retriever = self.get_dummy_pytorch_distributed_retriever(init_retrieval=True) retriever = self.get_dummy_pytorch_distributed_retriever(init_retrieval=True)
...@@ -154,3 +192,33 @@ class RagRetrieverTest(TestCase): ...@@ -154,3 +192,33 @@ class RagRetrieverTest(TestCase):
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
self.assertListEqual(doc_ids.tolist(), [[1], [0]]) self.assertListEqual(doc_ids.tolist(), [[1], [0]])
def test_custom_hf_index_retriever_retrieve(self):
    """Retrieve against a custom in-memory index passed directly to the retriever."""
    n_docs = 1
    retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=False)
    # Two opposite query vectors: all ones and all minus ones.
    hidden_states = np.array(
        [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
    )
    retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
    self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
    self.assertEqual(len(doc_dicts), 2)
    self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
    self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
    self.assertEqual(doc_dicts[0]["id"][0], "1")  # max inner product is reached with second doc
    self.assertEqual(doc_dicts[1]["id"][0], "0")  # max inner product is reached with first doc
    self.assertListEqual(doc_ids.tolist(), [[1], [0]])
def test_custom_pytorch_distributed_retriever_retrieve_from_disk(self):
    """Same retrieval check, but the dataset and faiss index are reloaded from disk."""
    n_docs = 1
    retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=True)
    # Two opposite query vectors: all ones and all minus ones.
    hidden_states = np.array(
        [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
    )
    retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
    self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
    self.assertEqual(len(doc_dicts), 2)
    self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
    self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
    self.assertEqual(doc_dicts[0]["id"][0], "1")  # max inner product is reached with second doc
    self.assertEqual(doc_dicts[1]["id"][0], "0")  # max inner product is reached with first doc
    self.assertListEqual(doc_ids.tolist(), [[1], [0]])
import logging
import os
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Optional
import torch
from datasets import load_dataset
import faiss
from transformers import (
DPRContextEncoder,
DPRContextEncoderTokenizerFast,
HfArgumentParser,
RagRetriever,
RagSequenceForGeneration,
RagTokenizer,
)
logger = logging.getLogger(__name__)
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
def split_text(text: str, n=100, character=" ") -> List[str]:
    """Split ``text`` every ``n``-th occurrence of ``character``.

    Each returned chunk is ``character``-joined back together and stripped.
    """
    words = text.split(character)
    chunks = []
    for start in range(0, len(words), n):
        chunks.append(character.join(words[start : start + n]).strip())
    return chunks
def split_documents(documents: dict) -> dict:
    """Split documents into passages.

    Each (title, text) pair is expanded into one row per 100-word passage,
    with the title repeated for every passage of its document.
    """
    out_titles: List[str] = []
    out_texts: List[str] = []
    for title, text in zip(documents["title"], documents["text"]):
        passages = split_text(text)
        out_texts.extend(passages)
        out_titles.extend([title] * len(passages))
    return {"title": out_titles, "text": out_texts}
def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast) -> dict:
    """Compute the DPR embeddings of document passages."""
    # Tokenize title/text pairs together so DPR sees them as a single sequence.
    batch = ctx_tokenizer(
        documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
    )
    pooled = ctx_encoder(batch["input_ids"].to(device=device), return_dict=True).pooler_output
    # Move to CPU numpy so the column can be stored in the datasets arrow table.
    return {"embeddings": pooled.detach().cpu().numpy()}
def main(
    rag_example_args: "RagExampleArguments",
    processing_args: "ProcessingArguments",
    index_hnsw_args: "IndexHnswArguments",
):
    """Build a custom RAG knowledge dataset from a csv file, index it, and query it once.

    Side effects: downloads models, writes the passages dataset and the faiss index
    under ``rag_example_args.output_dir``, and logs the generated answer.
    """

    ######################################
    logger.info("Step 1 - Create the dataset")
    ######################################

    # The dataset needed for RAG must have three columns:
    # - title (string): title of the document
    # - text (string): text of a passage of the document
    # - embeddings (array of dimension d): DPR representation of the passage

    # Let's say you have documents in tab-separated csv files with columns "title" and "text"
    assert os.path.isfile(rag_example_args.csv_path), "Please provide a valid path to a csv file"

    # You can load a Dataset object this way
    dataset = load_dataset(
        "csv", data_files=[rag_example_args.csv_path], split="train", delimiter="\t", column_names=["title", "text"]
    )

    # More info about loading csv files in the documentation: https://huggingface.co/docs/datasets/loading_datasets.html?highlight=csv#csv-files

    # Then split the documents into passages of 100 words
    dataset = dataset.map(split_documents, batched=True, num_proc=processing_args.num_proc)

    # And compute the embeddings
    ctx_encoder = DPRContextEncoder.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
    ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name)
    dataset = dataset.map(
        partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
        batched=True,
        batch_size=processing_args.batch_size,
    )

    # And finally save your dataset
    passages_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset")
    dataset.save_to_disk(passages_path)
    # from datasets import load_from_disk
    # dataset = load_from_disk(passages_path)  # to reload the dataset

    ######################################
    logger.info("Step 2 - Index the dataset")
    ######################################

    # Let's use the Faiss implementation of HNSW for fast approximate nearest neighbor search
    index = faiss.IndexHNSWFlat(index_hnsw_args.d, index_hnsw_args.m, faiss.METRIC_INNER_PRODUCT)
    dataset.add_faiss_index("embeddings", custom_index=index)

    # And save the index
    index_path = os.path.join(rag_example_args.output_dir, "my_knowledge_dataset_hnsw_index.faiss")
    dataset.get_index("embeddings").save(index_path)
    # dataset.load_faiss_index("embeddings", index_path)  # to reload the index

    ######################################
    logger.info("Step 3 - Load RAG")
    ######################################

    # Easy way to load the model
    retriever = RagRetriever.from_pretrained(
        rag_example_args.rag_model_name, index_name="custom", indexed_dataset=dataset
    )
    model = RagSequenceForGeneration.from_pretrained(rag_example_args.rag_model_name, retriever=retriever)
    tokenizer = RagTokenizer.from_pretrained(rag_example_args.rag_model_name)

    # For distributed fine-tuning you'll need to provide the paths instead, as the dataset and the index are loaded separately.
    # retriever = RagRetriever.from_pretrained(rag_model_name, index_name="custom", passages_path=passages_path, index_path=index_path)

    ######################################
    logger.info("Step 4 - Have fun")
    ######################################

    question = rag_example_args.question or "What does Moses' rod turn into ?"
    input_ids = tokenizer.question_encoder(question, return_tensors="pt")["input_ids"]
    generated = model.generate(input_ids)
    generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
    logger.info("Q: " + question)
    logger.info("A: " + generated_string)
@dataclass
class RagExampleArguments:
    """CLI arguments selecting the input data and the RAG / DPR models to use."""

    # Tab-separated csv with "title" and "text" columns (consumed by split_documents).
    csv_path: str = field(
        default=str(Path(__file__).parent / "test_data" / "my_knowledge_dataset.csv"),
        metadata={"help": "Path to a tab-separated csv file with columns 'title' and 'text'"},
    )
    # Question asked at the end of the demo; main() substitutes a default when None.
    question: Optional[str] = field(
        default=None,
        metadata={"help": "Question that is passed as input to RAG. Default is 'What does Moses' rod turn into ?'."},
    )
    # RAG checkpoint to load on top of the custom index.
    rag_model_name: str = field(
        default="facebook/rag-sequence-nq",
        metadata={"help": "The RAG model to use. Either 'facebook/rag-sequence-nq' or 'facebook/rag-token-nq'"},
    )
    # DPR context encoder used to embed the passages.
    dpr_ctx_encoder_model_name: str = field(
        default="facebook/dpr-ctx_encoder-multiset-base",
        metadata={
            "help": "The DPR context encoder model to use. Either 'facebook/dpr-ctx_encoder-single-nq-base' or 'facebook/dpr-ctx_encoder-multiset-base'"
        },
    )
    # Where to write the passages dataset and the faiss index; a temporary
    # directory is used when left as None (see the __main__ block).
    output_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to a directory where the dataset passages and the index will be saved"},
    )
@dataclass
class ProcessingArguments:
    """CLI arguments controlling how the passages are preprocessed and embedded."""

    # Parallelism for the split_documents map step; None means single process.
    num_proc: Optional[int] = field(
        default=None,
        metadata={
            "help": "The number of processes to use to split the documents into passages. Default is single process."
        },
    )
    # Batch size for the DPR embedding map step.
    batch_size: int = field(
        default=16,
        metadata={
            "help": "The batch size to use when computing the passages embeddings using the DPR context encoder."
        },
    )
@dataclass
class IndexHnswArguments:
    """CLI arguments for the faiss HNSW index (see faiss.IndexHNSWFlat)."""

    # Embedding dimension; must match the DPR context encoder output size.
    d: int = field(
        default=768,
        metadata={"help": "The dimension of the embeddings to pass to the HNSW Faiss index."},
    )
    # HNSW connectivity parameter: links created per new element at build time.
    m: int = field(
        default=128,
        metadata={
            "help": "The number of bi-directional links created for every new element during the HNSW index construction."
        },
    )
if __name__ == "__main__":
logging.basicConfig(level=logging.WARNING)
logger.setLevel(logging.INFO)
parser = HfArgumentParser((RagExampleArguments, ProcessingArguments, IndexHnswArguments))
rag_example_args, processing_args, index_hnsw_args = parser.parse_args_into_dataclasses()
with TemporaryDirectory() as tmp_dir:
rag_example_args.output_dir = rag_example_args.output_dir or tmp_dir
main(rag_example_args, processing_args, index_hnsw_args)
...@@ -24,6 +24,9 @@ DPR_PRETRAINED_CONFIG_ARCHIVE_MAP = { ...@@ -24,6 +24,9 @@ DPR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-ctx_encoder-single-nq-base/config.json", "facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-ctx_encoder-single-nq-base/config.json",
"facebook/dpr-question_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-question_encoder-single-nq-base/config.json", "facebook/dpr-question_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-question_encoder-single-nq-base/config.json",
"facebook/dpr-reader-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-reader-single-nq-base/config.json", "facebook/dpr-reader-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-reader-single-nq-base/config.json",
"facebook/dpr-ctx_encoder-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-ctx_encoder-multiset-base/config.json",
"facebook/dpr-question_encoder-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-question_encoder-multiset-base/config.json",
"facebook/dpr-reader-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-reader-multiset-base/config.json",
} }
......
...@@ -41,7 +41,7 @@ RAG_CONFIG_DOC = r""" ...@@ -41,7 +41,7 @@ RAG_CONFIG_DOC = r"""
Retrieval batch size, defined as the number of queries issued concurrently to the faiss index encapsulated Retrieval batch size, defined as the number of queries issued concurrently to the faiss index encapsulated
:class:`~transformers.RagRetriever`. :class:`~transformers.RagRetriever`.
dataset (:obj:`str`, `optional`, defaults to :obj:`"wiki_dpr"`): dataset (:obj:`str`, `optional`, defaults to :obj:`"wiki_dpr"`):
A dataset identifier of the indexed dataset on HuggingFace AWS bucket (list all available datasets and A dataset identifier of the indexed dataset in HuggingFace Datasets (list all available datasets and
ids using :obj:`datasets.list_datasets()`). ids using :obj:`datasets.list_datasets()`).
dataset_split (:obj:`str`, `optional`, defaults to :obj:`"train"`) dataset_split (:obj:`str`, `optional`, defaults to :obj:`"train"`)
Which split of the :obj:`dataset` to load. Which split of the :obj:`dataset` to load.
......
...@@ -35,12 +35,15 @@ _CONFIG_FOR_DOC = "DPRConfig" ...@@ -35,12 +35,15 @@ _CONFIG_FOR_DOC = "DPRConfig"
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST = [ DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/dpr-ctx_encoder-single-nq-base", "facebook/dpr-ctx_encoder-single-nq-base",
"facebook/dpr-ctx_encoder-multiset-base",
] ]
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST = [ DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/dpr-question_encoder-single-nq-base", "facebook/dpr-question_encoder-single-nq-base",
"facebook/dpr-question_encoder-multiset-base",
] ]
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST = [ DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/dpr-reader-single-nq-base", "facebook/dpr-reader-single-nq-base",
"facebook/dpr-reader-multiset-base",
] ]
......
...@@ -36,7 +36,7 @@ from .utils import logging ...@@ -36,7 +36,7 @@ from .utils import logging
if is_datasets_available(): if is_datasets_available():
from datasets import load_dataset from datasets import Dataset, load_dataset, load_from_disk
if is_faiss_available(): if is_faiss_available():
import faiss import faiss
...@@ -93,7 +93,7 @@ class Index: ...@@ -93,7 +93,7 @@ class Index:
raise NotImplementedError raise NotImplementedError
class LegacyIndex: class LegacyIndex(Index):
""" """
An index which can be deserialized from the files built using https://github.com/facebookresearch/DPR. An index which can be deserialized from the files built using https://github.com/facebookresearch/DPR.
We use default faiss index parameters as specified in that repository. We use default faiss index parameters as specified in that repository.
...@@ -115,7 +115,7 @@ class LegacyIndex: ...@@ -115,7 +115,7 @@ class LegacyIndex:
self.passages = self._load_passages() self.passages = self._load_passages()
self.vector_size = vector_size self.vector_size = vector_size
self.index = None self.index = None
self._index_initialize = False self._index_initialized = False
def _resolve_path(self, index_path, filename): def _resolve_path(self, index_path, filename):
assert os.path.isdir(index_path) or is_remote_url(index_path), "Please specify a valid ``index_path``." assert os.path.isdir(index_path) or is_remote_url(index_path), "Please specify a valid ``index_path``."
...@@ -157,7 +157,7 @@ class LegacyIndex: ...@@ -157,7 +157,7 @@ class LegacyIndex:
), "Deserialized index_id_to_db_id should match faiss index size" ), "Deserialized index_id_to_db_id should match faiss index size"
def is_initialized(self): def is_initialized(self):
return self._index_initialize return self._index_initialized
def init_index(self): def init_index(self):
index = faiss.IndexHNSWFlat(self.vector_size + 1, 512) index = faiss.IndexHNSWFlat(self.vector_size + 1, 512)
...@@ -165,7 +165,7 @@ class LegacyIndex: ...@@ -165,7 +165,7 @@ class LegacyIndex:
index.hnsw.efConstruction = 200 index.hnsw.efConstruction = 200
self.index = index self.index = index
self._deserialize_index() self._deserialize_index()
self._index_initialize = True self._index_initialized = True
def get_doc_dicts(self, doc_ids: np.array): def get_doc_dicts(self, doc_ids: np.array):
doc_list = [] doc_list = []
...@@ -190,13 +190,56 @@ class LegacyIndex: ...@@ -190,13 +190,56 @@ class LegacyIndex:
return np.array(ids), np.array(vectors) return np.array(ids), np.array(vectors)
class HFIndex: class HFIndexBase(Index):
def __init__(self, vector_size, dataset, index_initialized=False):
    # vector_size: dimension of the vectors in the "embeddings" column.
    # dataset: a datasets.Dataset with "title", "text" and "embeddings" columns
    #   (enforced by _check_dataset_format below).
    # index_initialized: True when `dataset` already carries a faiss index on "embeddings".
    self.vector_size = vector_size
    self.dataset = dataset
    self._index_initialized = index_initialized
    self._check_dataset_format(with_index=index_initialized)
    # Return "embeddings" as numpy arrays while keeping all other columns available.
    dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)
def _check_dataset_format(self, with_index: bool):
    # Validate that self.dataset matches the schema RAG retrieval relies on;
    # with_index additionally requires a faiss index on the "embeddings" column.
    if not isinstance(self.dataset, Dataset):
        raise ValueError("Dataset should be a datasets.Dataset object, but got {}".format(type(self.dataset)))
    if len({"title", "text", "embeddings"} - set(self.dataset.column_names)) > 0:
        raise ValueError(
            "Dataset should be a dataset with the following columns: "
            "title (str), text (str) and embeddings (arrays of dimension vector_size), "
            "but got columns {}".format(self.dataset.column_names)
        )
    if with_index and "embeddings" not in self.dataset.list_indexes():
        raise ValueError(
            "Missing faiss index in the dataset. Make sure you called `dataset.add_faiss_index` to compute it "
            "or `dataset.load_faiss_index` to load one from the disk."
        )
def init_index(self):
    # Subclasses decide how the faiss index is built or loaded.
    raise NotImplementedError()
def is_initialized(self):
    # True once the faiss index on "embeddings" is ready for search.
    return self._index_initialized
def get_doc_dicts(self, doc_ids: np.ndarray) -> List[dict]:
    # One dict of columns per batch entry, containing the rows for that entry's doc ids.
    return [self.dataset[doc_ids[i].tolist()] for i in range(doc_ids.shape[0])]
def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]:
    # Batched faiss search over the "embeddings" column.
    _, ids = self.dataset.search_batch("embeddings", question_hidden_states, n_docs)
    # faiss reports missing neighbors as -1; drop those before fetching rows.
    docs = [self.dataset[[i for i in indices if i >= 0]] for indices in ids]
    vectors = [doc["embeddings"] for doc in docs]
    for i in range(len(vectors)):
        if len(vectors[i]) < n_docs:
            # Pad with zero vectors so every batch entry has exactly n_docs embeddings;
            # note the corresponding ids stay -1.
            vectors[i] = np.vstack([vectors[i], np.zeros((n_docs - len(vectors[i]), self.vector_size))])
    return np.array(ids), np.array(vectors)  # shapes (batch_size, n_docs) and (batch_size, n_docs, d)
class CanonicalHFIndex(HFIndexBase):
""" """
A wrapper around an instance of :class:`~datasets.Datasets`. If ``index_path`` is set to ``None``, A wrapper around an instance of :class:`~datasets.Datasets`. If ``index_path`` is set to ``None``,
we load the pre-computed index available with the :class:`~datasets.arrow_dataset.Dataset`, otherwise, we load the index from the indicated path on disk. we load the pre-computed index available with the :class:`~datasets.arrow_dataset.Dataset`, otherwise, we load the index from the indicated path on disk.
Args: Args:
dataset (:obj:`str`, optional, defaults to ``wiki_dpr``): vector_size (:obj:`int`): the dimension of the passages embeddings used by the index
dataset_name (:obj:`str`, optional, defaults to ``wiki_dpr``):
A dataset identifier of the indexed dataset on HuggingFace AWS bucket (list all available datasets and ids with ``datasets.list_datasets()``). A dataset identifier of the indexed dataset on HuggingFace AWS bucket (list all available datasets and ids with ``datasets.list_datasets()``).
dataset_split (:obj:`str`, optional, defaults to ``train``) dataset_split (:obj:`str`, optional, defaults to ``train``)
Which split of the ``dataset`` to load. Which split of the ``dataset`` to load.
...@@ -204,39 +247,35 @@ class HFIndex: ...@@ -204,39 +247,35 @@ class HFIndex:
The index_name of the index associated with the ``dataset``. The index loaded from ``index_path`` will be saved under this name. The index_name of the index associated with the ``dataset``. The index loaded from ``index_path`` will be saved under this name.
index_path (:obj:`str`, optional, defaults to ``None``) index_path (:obj:`str`, optional, defaults to ``None``)
The path to the serialized faiss index on disk. The path to the serialized faiss index on disk.
use_dummy_dataset (:obj:`bool`, optional, defaults to ``False``): If True, use the dummy configuration of the dataset for tests.
""" """
def __init__( def __init__(
self, self,
dataset_name: str,
dataset_split: str,
index_name: str,
vector_size: int, vector_size: int,
dataset_name: str = "wiki_dpr",
dataset_split: str = "train",
index_name: Optional[str] = None,
index_path: Optional[str] = None, index_path: Optional[str] = None,
use_dummy_dataset=False, use_dummy_dataset=False,
): ):
super().__init__() if int(index_path is None) + int(index_name is None) != 1:
raise ValueError("Please provide `index_name` or `index_path`.")
self.dataset_name = dataset_name self.dataset_name = dataset_name
self.dataset_split = dataset_split self.dataset_split = dataset_split
self.index_name = index_name self.index_name = index_name
self.vector_size = vector_size
self.index_path = index_path self.index_path = index_path
self.use_dummy_dataset = use_dummy_dataset self.use_dummy_dataset = use_dummy_dataset
self._index_initialize = False
logger.info("Loading passages from {}".format(self.dataset_name)) logger.info("Loading passages from {}".format(self.dataset_name))
self.dataset = load_dataset( dataset = load_dataset(
self.dataset_name, with_index=False, split=self.dataset_split, dummy=self.use_dummy_dataset self.dataset_name, with_index=False, split=self.dataset_split, dummy=self.use_dummy_dataset
) )
self.dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True) super().__init__(vector_size, dataset, index_initialized=False)
def is_initialized(self):
return self._index_initialize
def init_index(self): def init_index(self):
if self.index_path is not None: if self.index_path is not None:
logger.info("Loading index from {}".format(self.index_path)) logger.info("Loading index from {}".format(self.index_path))
self.index.load_faiss_index(index_name=self.index_name, file=self.index_path) self.dataset.load_faiss_index("embeddings", file=self.index_path)
else: else:
logger.info("Loading index from {}".format(self.dataset_name + " with index name " + self.index_name)) logger.info("Loading index from {}".format(self.dataset_name + " with index name " + self.index_name))
self.dataset = load_dataset( self.dataset = load_dataset(
...@@ -248,19 +287,43 @@ class HFIndex: ...@@ -248,19 +287,43 @@ class HFIndex:
dummy=self.use_dummy_dataset, dummy=self.use_dummy_dataset,
) )
self.dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True) self.dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)
self._index_initialize = True self._index_initialized = True
def get_doc_dicts(self, doc_ids: np.ndarray) -> List[dict]:
return [self.dataset[doc_ids[i].tolist()] for i in range(doc_ids.shape[0])]
def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]: class CustomHFIndex(HFIndexBase):
_, ids = self.dataset.search_batch("embeddings", question_hidden_states, n_docs) """
docs = [self.dataset[[i for i in indices if i >= 0]] for indices in ids] A wrapper around an instance of :class:`~datasets.Datasets`.
vectors = [doc["embeddings"] for doc in docs] The dataset and the index are both loaded from the indicated paths on disk.
for i in range(len(vectors)):
if len(vectors[i]) < n_docs: Args:
vectors[i] = np.vstack([vectors[i], np.zeros((n_docs - len(vectors[i]), self.vector_size))]) vector_size (:obj:`int`): the dimension of the passages embeddings used by the index
return np.array(ids), np.array(vectors) # shapes (batch_size, n_docs) and (batch_size, n_docs, d) dataset_path (:obj:`str`):
The path to the serialized dataset on disk.
The dataset should have 3 columns: title (str), text (str) and embeddings (arrays of dimension vector_size)
index_path (:obj:`str`)
The path to the serialized faiss index on disk.
"""
def __init__(self, vector_size: int, dataset, index_path=None):
super().__init__(vector_size, dataset, index_initialized=index_path is None)
self.index_path = index_path
@classmethod
def load_from_disk(cls, vector_size, dataset_path, index_path):
logger.info("Loading passages from {}".format(dataset_path))
if dataset_path is None or index_path is None:
raise ValueError(
"Please provide ``dataset_path`` and ``index_path`` after calling ``dataset.save_to_disk(dataset_path)`` "
"and ``dataset.get_index('embeddings').save(index_path)``."
)
dataset = load_from_disk(dataset_path)
return cls(vector_size=vector_size, dataset=dataset, index_path=index_path)
def init_index(self):
if not self.is_initialized():
logger.info("Loading index from {}".format(self.index_path))
self.dataset.load_faiss_index("embeddings", file=self.index_path)
self._index_initialized = True
class RagRetriever: class RagRetriever:
...@@ -271,34 +334,46 @@ class RagRetriever: ...@@ -271,34 +334,46 @@ class RagRetriever:
Args: Args:
config (:class:`~transformers.RagConfig`): config (:class:`~transformers.RagConfig`):
The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build. The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build.
You can load your own custom dataset with ``config.index_name="custom"`` or use a canonical one (default) from the datasets library
with ``config.index_name="wiki_dpr"`` for example.
question_encoder_tokenizer (:class:`~transformers.PreTrainedTokenizer`): question_encoder_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
The tokenizer that was used to tokenize the question. The tokenizer that was used to tokenize the question.
It is used to decode the question and then use the generator_tokenizer. It is used to decode the question and then use the generator_tokenizer.
generator_tokenizer (:class:`~transformers.PreTrainedTokenizer`): generator_tokenizer (:class:`~transformers.PreTrainedTokenizer`):
The tokenizer used for the generator part of the RagModel. The tokenizer used for the generator part of the RagModel.
index (:class:`~transformers.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
If specified, use this index instead of the one built using the configuration
Examples::
>>> # To load the default "wiki_dpr" dataset with 21M passages from wikipedia (index name is 'compressed' or 'exact')
>>> from transformers import RagRetriever
>>> retriever = RagRetriever.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base', dataset="wiki_dpr", index_name='compressed')
>>> # To load your own indexed dataset built with the datasets library. More info on how to build the indexed dataset in examples/rag/use_own_knowledge_dataset.py
>>> from transformers import RagRetriever
>>> dataset = ... # dataset must be a datasets.Datasets object with columns "title", "text" and "embeddings", and it must have a faiss index
>>> retriever = RagRetriever.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base', indexed_dataset=dataset)
>>> # To load your own indexed dataset built with the datasets library that was saved on disk. More info in examples/rag/use_own_knowledge_dataset.py
>>> from transformers import RagRetriever
>>> dataset_path = "path/to/my/dataset" # dataset saved via `dataset.save_to_disk(...)`
>>> index_path = "path/to/my/index.faiss" # faiss index saved via `dataset.get_index("embeddings").save(...)`
>>> retriever = RagRetriever.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base', index_name='custom', passages_path=dataset_path, index_path=index_path)
>>> # To load the legacy index built originally for Rag's paper
>>> from transformers import RagRetriever
>>> retriever = RagRetriever.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base', index_name='legacy')
""" """
_init_retrieval = True _init_retrieval = True
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer): def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None):
requires_datasets(self) requires_datasets(self)
requires_faiss(self) requires_faiss(self)
super().__init__() super().__init__()
self.index = ( self.index = index or self._build_index(config)
LegacyIndex(
config.retrieval_vector_size,
config.index_path or LEGACY_INDEX_PATH,
)
if config.index_name == "legacy"
else HFIndex(
config.dataset,
config.dataset_split,
config.index_name,
config.retrieval_vector_size,
config.index_path,
config.use_dummy_dataset,
)
)
self.generator_tokenizer = generator_tokenizer self.generator_tokenizer = generator_tokenizer
self.question_encoder_tokenizer = question_encoder_tokenizer self.question_encoder_tokenizer = question_encoder_tokenizer
...@@ -309,19 +384,62 @@ class RagRetriever: ...@@ -309,19 +384,62 @@ class RagRetriever:
if self._init_retrieval: if self._init_retrieval:
self.init_retrieval() self.init_retrieval()
@staticmethod
def _build_index(config):
if config.index_name == "legacy":
return LegacyIndex(
config.retrieval_vector_size,
config.index_path or LEGACY_INDEX_PATH,
)
elif config.index_name == "custom":
return CustomHFIndex.load_from_disk(
vector_size=config.retrieval_vector_size,
dataset_path=config.passages_path,
index_path=config.index_path,
)
else:
return CanonicalHFIndex(
vector_size=config.retrieval_vector_size,
dataset_name=config.dataset,
dataset_split=config.dataset_split,
index_name=config.index_name,
index_path=config.index_path,
use_dummy_dataset=config.use_dummy_dataset,
)
@classmethod @classmethod
def from_pretrained(cls, retriever_name_or_path, **kwargs): def from_pretrained(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
requires_datasets(cls) requires_datasets(cls)
requires_faiss(cls) requires_faiss(cls)
config = RagConfig.from_pretrained(retriever_name_or_path, **kwargs) config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config) rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
question_encoder_tokenizer = rag_tokenizer.question_encoder question_encoder_tokenizer = rag_tokenizer.question_encoder
generator_tokenizer = rag_tokenizer.generator generator_tokenizer = rag_tokenizer.generator
if indexed_dataset is not None:
config.index_name = "custom"
index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset)
else:
index = cls._build_index(config)
return cls( return cls(
config, question_encoder_tokenizer=question_encoder_tokenizer, generator_tokenizer=generator_tokenizer config,
question_encoder_tokenizer=question_encoder_tokenizer,
generator_tokenizer=generator_tokenizer,
index=index,
) )
def save_pretrained(self, save_directory): def save_pretrained(self, save_directory):
if isinstance(self.index, CustomHFIndex):
if self.config.index_path is None:
index_path = os.path.join(save_directory, "hf_dataset_index.faiss")
self.index.dataset.get_index("embeddings").save(index_path)
self.config.index_path = index_path
if self.config.passages_path is None:
passages_path = os.path.join(save_directory, "hf_dataset")
# datasets don't support save_to_disk with indexes right now
faiss_index = self.index.dataset._indexes.pop("embeddings")
self.index.dataset.save_to_disk(passages_path)
self.index.dataset._indexes["embeddings"] = faiss_index
self.config.passages_path = passages_path
self.config.save_pretrained(save_directory) self.config.save_pretrained(save_directory)
rag_tokenizer = RagTokenizer( rag_tokenizer = RagTokenizer(
question_encoder=self.question_encoder_tokenizer, question_encoder=self.question_encoder_tokenizer,
......
...@@ -26,43 +26,64 @@ from .utils import logging ...@@ -26,43 +26,64 @@ from .utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = { CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": { "vocab_file": {
"facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", "facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
} "facebook/dpr-ctx_encoder-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
},
"tokenizer_file": {
"facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
"facebook/dpr-ctx_encoder-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
},
} }
QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = { QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": { "vocab_file": {
"facebook/dpr-question_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", "facebook/dpr-question_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
} "facebook/dpr-question_encoder-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
},
"tokenizer_file": {
"facebook/dpr-question_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
"facebook/dpr-question_encoder-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
},
} }
READER_PRETRAINED_VOCAB_FILES_MAP = { READER_PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": { "vocab_file": {
"facebook/dpr-reader-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", "facebook/dpr-reader-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
} "facebook/dpr-reader-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
},
"tokenizer_file": {
"facebook/dpr-reader-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
"facebook/dpr-reader-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
},
} }
CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"facebook/dpr-ctx_encoder-single-nq-base": 512, "facebook/dpr-ctx_encoder-single-nq-base": 512,
"facebook/dpr-ctx_encoder-multiset-base": 512,
} }
QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"facebook/dpr-question_encoder-single-nq-base": 512, "facebook/dpr-question_encoder-single-nq-base": 512,
"facebook/dpr-question_encoder-multiset-base": 512,
} }
READER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { READER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"facebook/dpr-reader-single-nq-base": 512, "facebook/dpr-reader-single-nq-base": 512,
"facebook/dpr-reader-multiset-base": 512,
} }
CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = { CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
"facebook/dpr-ctx_encoder-single-nq-base": {"do_lower_case": True}, "facebook/dpr-ctx_encoder-single-nq-base": {"do_lower_case": True},
"facebook/dpr-ctx_encoder-multiset-base": {"do_lower_case": True},
} }
QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = { QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
"facebook/dpr-question_encoder-single-nq-base": {"do_lower_case": True}, "facebook/dpr-question_encoder-single-nq-base": {"do_lower_case": True},
"facebook/dpr-question_encoder-multiset-base": {"do_lower_case": True},
} }
READER_PRETRAINED_INIT_CONFIGURATION = { READER_PRETRAINED_INIT_CONFIGURATION = {
"facebook/dpr-reader-single-nq-base": {"do_lower_case": True}, "facebook/dpr-reader-single-nq-base": {"do_lower_case": True},
"facebook/dpr-reader-multiset-base": {"do_lower_case": True},
} }
......
...@@ -32,47 +32,59 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.jso ...@@ -32,47 +32,59 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.jso
CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = { CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": { "vocab_file": {
"facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", "facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
"facebook/dpr-ctx_encoder-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
}, },
"tokenizer_file": { "tokenizer_file": {
"facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json", "facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
"facebook/dpr-ctx_encoder-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
}, },
} }
QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = { QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": { "vocab_file": {
"facebook/dpr-question_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", "facebook/dpr-question_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
"facebook/dpr-question_encoder-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
}, },
"tokenizer_file": { "tokenizer_file": {
"facebook/dpr-question_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json", "facebook/dpr-question_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
"facebook/dpr-question_encoder-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
}, },
} }
READER_PRETRAINED_VOCAB_FILES_MAP = { READER_PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": { "vocab_file": {
"facebook/dpr-reader-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", "facebook/dpr-reader-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
"facebook/dpr-reader-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
}, },
"tokenizer_file": { "tokenizer_file": {
"facebook/dpr-reader-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json", "facebook/dpr-reader-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
"facebook/dpr-reader-multiset-base": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tokenizer.json",
}, },
} }
CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"facebook/dpr-ctx_encoder-single-nq-base": 512, "facebook/dpr-ctx_encoder-single-nq-base": 512,
"facebook/dpr-ctx_encoder-multiset-base": 512,
} }
QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"facebook/dpr-question_encoder-single-nq-base": 512, "facebook/dpr-question_encoder-single-nq-base": 512,
"facebook/dpr-question_encoder-multiset-base": 512,
} }
READER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { READER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"facebook/dpr-reader-single-nq-base": 512, "facebook/dpr-reader-single-nq-base": 512,
"facebook/dpr-reader-multiset-base": 512,
} }
CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = { CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
"facebook/dpr-ctx_encoder-single-nq-base": {"do_lower_case": True}, "facebook/dpr-ctx_encoder-single-nq-base": {"do_lower_case": True},
"facebook/dpr-ctx_encoder-multiset-base": {"do_lower_case": True},
} }
QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = { QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
"facebook/dpr-question_encoder-single-nq-base": {"do_lower_case": True}, "facebook/dpr-question_encoder-single-nq-base": {"do_lower_case": True},
"facebook/dpr-question_encoder-multiset-base": {"do_lower_case": True},
} }
READER_PRETRAINED_INIT_CONFIGURATION = { READER_PRETRAINED_INIT_CONFIGURATION = {
"facebook/dpr-reader-single-nq-base": {"do_lower_case": True}, "facebook/dpr-reader-single-nq-base": {"do_lower_case": True},
"facebook/dpr-reader-multiset-base": {"do_lower_case": True},
} }
......
...@@ -13,7 +13,7 @@ import faiss ...@@ -13,7 +13,7 @@ import faiss
from transformers.configuration_bart import BartConfig from transformers.configuration_bart import BartConfig
from transformers.configuration_dpr import DPRConfig from transformers.configuration_dpr import DPRConfig
from transformers.configuration_rag import RagConfig from transformers.configuration_rag import RagConfig
from transformers.retrieval_rag import RagRetriever from transformers.retrieval_rag import CustomHFIndex, RagRetriever
from transformers.testing_utils import ( from transformers.testing_utils import (
require_datasets, require_datasets,
require_faiss, require_faiss,
...@@ -103,7 +103,7 @@ class RagRetrieverTest(TestCase): ...@@ -103,7 +103,7 @@ class RagRetrieverTest(TestCase):
def tearDown(self): def tearDown(self):
shutil.rmtree(self.tmpdirname) shutil.rmtree(self.tmpdirname)
def get_dummy_hf_index_retriever(self): def get_dummy_dataset(self):
dataset = Dataset.from_dict( dataset = Dataset.from_dict(
{ {
"id": ["0", "1"], "id": ["0", "1"],
...@@ -113,6 +113,10 @@ class RagRetrieverTest(TestCase): ...@@ -113,6 +113,10 @@ class RagRetrieverTest(TestCase):
} }
) )
dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT) dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
return dataset
def get_dummy_canonical_hf_index_retriever(self):
dataset = self.get_dummy_dataset()
config = RagConfig( config = RagConfig(
retrieval_vector_size=self.retrieval_vector_size, retrieval_vector_size=self.retrieval_vector_size,
question_encoder=DPRConfig().to_dict(), question_encoder=DPRConfig().to_dict(),
...@@ -127,6 +131,35 @@ class RagRetrieverTest(TestCase): ...@@ -127,6 +131,35 @@ class RagRetrieverTest(TestCase):
) )
return retriever return retriever
def get_dummy_custom_hf_index_retriever(self, from_disk: bool):
dataset = self.get_dummy_dataset()
config = RagConfig(
retrieval_vector_size=self.retrieval_vector_size,
question_encoder=DPRConfig().to_dict(),
generator=BartConfig().to_dict(),
index_name="custom",
)
if from_disk:
config.passages_path = os.path.join(self.tmpdirname, "dataset")
config.index_path = os.path.join(self.tmpdirname, "index.faiss")
dataset.get_index("embeddings").save(os.path.join(self.tmpdirname, "index.faiss"))
dataset.drop_index("embeddings")
dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
del dataset
retriever = RagRetriever(
config,
question_encoder_tokenizer=self.get_dpr_tokenizer(),
generator_tokenizer=self.get_bart_tokenizer(),
)
else:
retriever = RagRetriever(
config,
question_encoder_tokenizer=self.get_dpr_tokenizer(),
generator_tokenizer=self.get_bart_tokenizer(),
index=CustomHFIndex(config.retrieval_vector_size, dataset),
)
return retriever
def get_dummy_legacy_index_retriever(self): def get_dummy_legacy_index_retriever(self):
dataset = Dataset.from_dict( dataset = Dataset.from_dict(
{ {
...@@ -152,16 +185,71 @@ class RagRetrieverTest(TestCase): ...@@ -152,16 +185,71 @@ class RagRetrieverTest(TestCase):
generator=BartConfig().to_dict(), generator=BartConfig().to_dict(),
index_name="legacy", index_name="legacy",
index_path=self.tmpdirname, index_path=self.tmpdirname,
passages_path=self.tmpdirname,
) )
retriever = RagRetriever( retriever = RagRetriever(
config, question_encoder_tokenizer=self.get_dpr_tokenizer(), generator_tokenizer=self.get_bart_tokenizer() config, question_encoder_tokenizer=self.get_dpr_tokenizer(), generator_tokenizer=self.get_bart_tokenizer()
) )
return retriever return retriever
def test_hf_index_retriever_retrieve(self): def test_canonical_hf_index_retriever_retrieve(self):
n_docs = 1
retriever = self.get_dummy_canonical_hf_index_retriever()
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertEqual(len(doc_dicts), 2)
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
def test_canonical_hf_index_retriever_save_and_from_pretrained(self):
retriever = self.get_dummy_canonical_hf_index_retriever()
with tempfile.TemporaryDirectory() as tmp_dirname:
with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
mock_load_dataset.return_value = self.get_dummy_dataset()
retriever.save_pretrained(tmp_dirname)
retriever = RagRetriever.from_pretrained(tmp_dirname)
self.assertIsInstance(retriever, RagRetriever)
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
out = retriever.retrieve(hidden_states, n_docs=1)
self.assertTrue(out is not None)
def test_custom_hf_index_retriever_retrieve(self):
n_docs = 1
retriever = self.get_dummy_custom_hf_index_retriever(from_disk=False)
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
self.assertEqual(len(doc_dicts), 2)
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
def test_custom_hf_index_retriever_save_and_from_pretrained(self):
retriever = self.get_dummy_custom_hf_index_retriever(from_disk=False)
with tempfile.TemporaryDirectory() as tmp_dirname:
retriever.save_pretrained(tmp_dirname)
retriever = RagRetriever.from_pretrained(tmp_dirname)
self.assertIsInstance(retriever, RagRetriever)
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
out = retriever.retrieve(hidden_states, n_docs=1)
self.assertTrue(out is not None)
def test_custom_hf_index_retriever_retrieve_from_disk(self):
n_docs = 1 n_docs = 1
retriever = self.get_dummy_hf_index_retriever() retriever = self.get_dummy_custom_hf_index_retriever(from_disk=True)
hidden_states = np.array( hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32 [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
) )
...@@ -174,10 +262,17 @@ class RagRetrieverTest(TestCase): ...@@ -174,10 +262,17 @@ class RagRetrieverTest(TestCase):
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
self.assertListEqual(doc_ids.tolist(), [[1], [0]]) self.assertListEqual(doc_ids.tolist(), [[1], [0]])
def test_save_and_from_pretrained(self): def test_custom_hf_index_retriever_save_and_from_pretrained_from_disk(self):
retriever = self.get_dummy_hf_index_retriever() retriever = self.get_dummy_custom_hf_index_retriever(from_disk=True)
with tempfile.TemporaryDirectory() as tmp_dirname: with tempfile.TemporaryDirectory() as tmp_dirname:
retriever.save_pretrained(tmp_dirname) retriever.save_pretrained(tmp_dirname)
retriever = RagRetriever.from_pretrained(tmp_dirname)
self.assertIsInstance(retriever, RagRetriever)
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
out = retriever.retrieve(hidden_states, n_docs=1)
self.assertTrue(out is not None)
def test_legacy_index_retriever_retrieve(self): def test_legacy_index_retriever_retrieve(self):
n_docs = 1 n_docs = 1
...@@ -194,6 +289,18 @@ class RagRetrieverTest(TestCase): ...@@ -194,6 +289,18 @@ class RagRetrieverTest(TestCase):
self.assertEqual(doc_dicts[1]["text"][0], "foo") # max inner product is reached with first doc self.assertEqual(doc_dicts[1]["text"][0], "foo") # max inner product is reached with first doc
self.assertListEqual(doc_ids.tolist(), [[1], [0]]) self.assertListEqual(doc_ids.tolist(), [[1], [0]])
def test_legacy_hf_index_retriever_save_and_from_pretrained(self):
retriever = self.get_dummy_legacy_index_retriever()
with tempfile.TemporaryDirectory() as tmp_dirname:
retriever.save_pretrained(tmp_dirname)
retriever = RagRetriever.from_pretrained(tmp_dirname)
self.assertIsInstance(retriever, RagRetriever)
hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
)
out = retriever.retrieve(hidden_states, n_docs=1)
self.assertTrue(out is not None)
@require_torch @require_torch
@require_tokenizers @require_tokenizers
@require_sentencepiece @require_sentencepiece
...@@ -201,7 +308,7 @@ class RagRetrieverTest(TestCase): ...@@ -201,7 +308,7 @@ class RagRetrieverTest(TestCase):
import torch import torch
n_docs = 1 n_docs = 1
retriever = self.get_dummy_hf_index_retriever() retriever = self.get_dummy_canonical_hf_index_retriever()
question_input_ids = [[5, 7], [10, 11]] question_input_ids = [[5, 7], [10, 11]]
hidden_states = np.array( hidden_states = np.array(
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32 [np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment