---
language: en
license: apache-2.0
datasets:
- wiki_dpr
---
## RAG

This is the RAG-Sequence Model of the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/pdf/2005.11401.pdf)
by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al.

The model is an *uncased* model, which means that capital letters are simply converted to lower-case letters.
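The lower-casing itself happens inside the model's tokenizer; the effect can be sketched in plain Python (illustrative only, not the tokenizer's actual implementation):

```python
def normalize_uncased(text: str) -> str:
    # Uncased models lower-case all input before tokenization,
    # so capitalization carries no signal for the model.
    return text.lower()

# Differently-cased questions collapse to the same form:
print(normalize_uncased("Who wrote Hamlet?"))   # who wrote hamlet?
print(normalize_uncased("WHO WROTE HAMLET?"))   # who wrote hamlet?
```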

The model consists of a *question_encoder*, a *retriever* and a *generator*. The retriever extracts relevant passages from the *wiki_dpr* `train` dataset, which is linked above.
The question_encoder is based on `facebook/dpr-question_encoder-single-nq-base` and the generator on `facebook/bart-large`; both were jointly fine-tuned on 
the *wiki_dpr* QA dataset in an end-to-end fashion.
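As described in the paper, RAG-Sequence uses the same retrieved passage for the entire generated sequence: the generator's output is marginalized over the top-k passages returned by the retriever:

```latex
p_{\text{RAG-Sequence}}(y \mid x)
  \approx \sum_{z \,\in\, \text{top-}k\left(p_\eta(\cdot \mid x)\right)}
      p_\eta(z \mid x) \, \prod_{i=1}^{N} p_\theta(y_i \mid x, z, y_{1:i-1})
```

Here `p_\eta(z | x)` is the retriever's score for passage `z` given question `x`, and `p_\theta` is the generator's token-level probability.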

## Usage:
**Note**: In the usage example below only the *dummy* retriever of *wiki_dpr* is used, because the complete *legacy* index requires over 75 GB of RAM.
The model can generate answers to any factoid question as follows:
```python
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration

tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True)
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

input_dict = tokenizer.prepare_seq2seq_batch("how many countries are in europe", return_tensors="pt")
generated = model.generate(input_ids=input_dict["input_ids"])
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
# should give 54 => google says either 44 or 51
```