Retrieval-augmented generation ("RAG") models combine the powers of pretrained dense retrieval (DPR) and Seq2Seq models.
Retrieval-augmented generation ("RAG") models combine the powers of pretrained dense retrieval (DPR) and
RAG models retrieve docs, pass them to a seq2seq model, then marginalize to generate outputs.
sequence-to-sequence models. RAG models retrieve documents, pass them to a seq2seq model, then marginalize to generate
The retriever and seq2seq modules are initialized from pretrained models, and fine-tuned jointly, allowing both retrieval and generation to adapt to downstream tasks.
outputs. The retriever and seq2seq modules are initialized from pretrained models, and fine-tuned jointly, allowing
both retrieval and generation to adapt to downstream tasks.
It is based on the paper `Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks <https://arxiv.org/abs/2005.11401>`__ by Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela.
It is based on the paper `Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks
<https://arxiv.org/abs/2005.11401>`__ by Patrick Lewis, Ethan Perez, Aleksandara Piktus, Fabio Petroni, Vladimir
Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, Sebastian Riedel, Douwe Kiela.
retrieval_vector_size (:obj:`int`, `optional`, defaults to 768):
retrieval_vector_size (:obj:`int`, `optional`, defaults to 768):
Dimensionality of the document embeddings indexed by :class:`~transformers.RagRetriever`.
Dimensionality of the document embeddings indexed by :class:`~transformers.RagRetriever`.
retrieval_batch_size (:obj:`int`, `optional`, defaults to 8):
retrieval_batch_size (:obj:`int`, `optional`, defaults to 8):
Retrieval batch size, defined as the number of queries issues concurrently to the faiss index excapsulated :class:`~transformers.RagRetriever`.
Retrieval batch size, defined as the number of queries issues concurrently to the faiss index excapsulated
:class:`~transformers.RagRetriever`.
dataset (:obj:`str`, `optional`, defaults to :obj:`"wiki_dpr"`):
dataset (:obj:`str`, `optional`, defaults to :obj:`"wiki_dpr"`):
A datatset identifier of the indexed dataset on HuggingFace AWS bucket (list all available datasets and ids using :obj:`datasets.list_datasets()`).
A dataset identifier of the indexed dataset on HuggingFace AWS bucket (list all available datasets and
dataset_split (:obj:`str`, `optional`, defaults to :obj:`train`)
ids using :obj:`datasets.list_datasets()`).
Which split of the ``dataset`` to load.
dataset_split (:obj:`str`, `optional`, defaults to :obj:`"train"`)
index_name (:obj:`str`, `optional`, defaults to :obj:`compressed`)
Which split of the :obj:`dataset` to load.
The index_name of the index associated with the :obj:`dataset`. One can choose between :obj:`legacy`, :obj:`exact` and :obj:`compressed`.
index_name (:obj:`str`, `optional`, defaults to :obj:`"compressed"`)
The index name of the index associated with the :obj:`dataset`. One can choose between :obj:`"legacy"`,
:obj:`"exact"` and :obj:`"compressed"`.
index_path (:obj:`str`, `optional`)
index_path (:obj:`str`, `optional`)
The path to the serialized faiss index on disk.
The path to the serialized faiss index on disk.
passages_path: (:obj:`str`, `optional`):
passages_path: (:obj:`str`, `optional`):
A path to text passages compatible with the faiss index. Required if using :class:`~transformers.retrieval_rag.LegacyIndex`
A path to text passages compatible with the faiss index. Required if using
:class:`~transformers.retrieval_rag.LegacyIndex`
use_dummy_dataset (:obj:`bool`, `optional`, defaults to ``False``)
use_dummy_dataset (:obj:`bool`, `optional`, defaults to ``False``)
Whether to load a "dummy" variant of the dataset specified by :obj:`dataset`.
Whether to load a "dummy" variant of the dataset specified by :obj:`dataset`.
label_smoothing (:obj:`float`, `optional`, defaults to 0.0):
label_smoothing (:obj:`float`, `optional`, defaults to 0.0):
Only relevant if ``return_loss`` is set to :obj:`True`. Controls the ``epsilon`` parameter value for label smoothing in the loss calculation.
Only relevant if ``return_loss`` is set to :obj:`True`. Controls the ``epsilon`` parameter value for label
If set to ``0.0``, no label smoothing is performed.
smoothing in the loss calculation. If set to 0, no label smoothing is performed.
do_marginalize (:obj:`bool`, `optional`, defaults to :obj:`False`):
do_marginalize (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, the logits are marginalized over all documents
If :obj:`True`, the logits are marginalized over all documents
by making use of ``torch.nn.functional.log_softmax``.
by making use of ``torch.nn.functional.log_softmax``.
reduce_loss (:obj:`bool`, `optional`, defaults to :obj:`False`):
reduce_loss (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, the NLL loss is reduced using the ``torch.Tensor.sum`` operation.
Whether or not to reduce the NLL loss using the ``torch.Tensor.sum`` operation.
do_deduplication (:obj:`bool`, `optional`, defaults to :obj:`True`):
do_deduplication (:obj:`bool`, `optional`, defaults to :obj:`True`):
Controls whether we want to deduplicate the generations from different context documents for a given input.
Whether or not to deduplicate the generations from different context documents for a given input.
Has to be set to :obj:`False` if used while training with distributed backend.
Has to be set to :obj:`False` if used while training with distributed backend.
exclude_bos_score (:obj:`bool`, `optional`, defaults to :obj:`False`):
exclude_bos_score (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, the score of the BOS token is disregarded when computing
Whether or not to disregard the BOS token when computing the loss.
the loss.
output_retrieved(:obj:`bool`, `optional`, defaults to :obj:`False`):
output_retrieved(:obj:`bool`, `optional`, defaults to :obj:`False`):
If set to ``True``, :obj:`retrieved_doc_embeds`, :obj:`retrieved_doc_ids`, :obj:`context_input_ids` and :obj:`context_attention_mask` are returned. See returned tensors for more detail.
If set to ``True``, :obj:`retrieved_doc_embeds`, :obj:`retrieved_doc_ids`, :obj:`context_input_ids` and
:obj:`context_attention_mask` are returned. See returned tensors for more detail.