---
language: en
license: apache-2.0
datasets:
- wiki_dpr
thumbnail: https://huggingface.co/front/thumbnails/facebook.png
---
## RAG

This is the RAG-Sequence Model of the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/pdf/2005.11401.pdf)
by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al.

The model is an *uncased* model, which means that capital letters are simply converted to lower-case letters.

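To see the effect of lower-casing directly, here is a minimal sketch (not part of the original card; the question is arbitrary) showing that differently-cased inputs map to identical token ids:

```python
from transformers import RagTokenizer

# RagTokenizer delegates to the (uncased) question-encoder tokenizer by default
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")

upper = tokenizer("How Many Countries Are In Europe", return_tensors="pt")
lower = tokenizer("how many countries are in europe", return_tensors="pt")

# both inputs are lower-cased before tokenization, so the ids match
assert upper["input_ids"].equal(lower["input_ids"])
```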
The model consists of a *question_encoder*, *retriever* and a *generator*. The retriever extracts relevant passages from the *wiki_dpr* `train` dataset, which is linked above.
The question_encoder and generator are based on `facebook/dpr-question_encoder-single-nq-base` and `facebook/bart-large`, respectively, which were jointly fine-tuned in an end-to-end fashion on the Natural Questions (NQ) dataset, with *wiki_dpr* serving as the retrieval corpus.

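The split into components is visible on the loaded model itself. As a quick sketch (assuming a recent `transformers` version; the printed class names are what these checkpoints typically resolve to), the question encoder and generator are submodules, while the retriever is passed in separately:

```python
from transformers import RagSequenceForGeneration

# the retriever is optional at load time; the question encoder and
# generator weights are part of the checkpoint itself
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")

print(type(model.question_encoder).__name__)  # e.g. DPRQuestionEncoder
print(type(model.generator).__name__)         # e.g. BartForConditionalGeneration
```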
## Usage:

**Note**: In the usage example below, only the *dummy* retriever of *wiki_dpr* is used, because the complete *legacy* index requires over 75 GB of RAM.
The model can generate answers to any factoid question as follows:

```python
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration

tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")

# the dummy dataset keeps memory usage small; see the note above
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True)
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

input_dict = tokenizer("how many countries are in europe", return_tensors="pt")

generated = model.generate(input_ids=input_dict["input_ids"])
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])

# should give 54 => Google says either 44 or 51
```
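To retrieve from the full index instead of the dummy one, the retriever can be loaded with a different `index_name` (a sketch assuming sufficient RAM; the *legacy* index alone needs the 75+ GB mentioned above):

```python
from transformers import RagRetriever

# index_name="legacy" loads the original index used in the paper;
# "exact" / "compressed" use the wiki_dpr dataset's FAISS indexes
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="legacy")
```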