---
language: en
license: apache-2.0
datasets:
- wiki_dpr
---
## RAG

This is the RAG-Sequence Model of the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/pdf/2005.11401.pdf)
by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al.

The model is an *uncased* model, which means that capital letters are simply converted to lower-case letters.
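
In practice this means capitalization variants of the same query are equivalent after normalization. A plain-Python sketch of the idea (illustrative only, not the tokenizer itself):

```python
# "Uncased" means the model does not distinguish capitalization:
# input text is lower-cased, so these two queries are seen identically.
queries = ["How many countries are in Europe?", "how many countries are in europe?"]
normalized = [q.lower() for q in queries]
assert normalized[0] == normalized[1]
```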

The model consists of a *question_encoder*, *retriever* and a *generator*. The retriever extracts relevant passages from the *wiki_dpr* `train` dataset, which is linked above.
The question_encoder and generator are based on `facebook/dpr-question_encoder-single-nq-base` and `facebook/bart-large` respectively, and were jointly fine-tuned
on the *wiki_dpr* QA dataset in an end-to-end fashion.
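
The RAG-Sequence formulation of the paper treats the retrieved passage as a latent variable and marginalizes the probability of the *whole* output sequence over passages: p(y|x) = Σ_z p(z|x) · Π_i p(y_i|x, z, y_<i). A minimal sketch of this marginalization with made-up probabilities (plain Python, not the actual model API):

```python
# Hypothetical retriever scores p(z|x) for two retrieved passages (sum to 1).
passage_probs = [0.7, 0.3]

# Hypothetical generator per-token probabilities p(y_i | x, z, y_<i) for a
# 3-token answer, one row per conditioning passage.
token_probs = [
    [0.9, 0.8, 0.95],  # conditioned on passage 0
    [0.2, 0.5, 0.60],  # conditioned on passage 1
]

def seq_prob(tokens):
    """Probability of the full sequence given one passage: product over tokens."""
    p = 1.0
    for t in tokens:
        p *= t
    return p

# RAG-Sequence: mix the per-passage *sequence* probabilities with p(z|x).
p_y = sum(pz * seq_prob(ts) for pz, ts in zip(passage_probs, token_probs))
print(round(p_y, 4))  # 0.4968
```

This is what distinguishes RAG-Sequence from RAG-Token, where the marginalization over passages instead happens independently at every generated token.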

## Usage:

**Note**: In the usage example below only the *dummy* retriever of *wiki_dpr* is used, because the real retriever requires over 40 GB of RAM.
The model can generate answers to any question as follows:

```python
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration

tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
# use_dummy_dataset=True loads a small dummy index; the full index requires over 40 GB of RAM
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True)
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

input_dict = tokenizer.prepare_seq2seq_batch("how many countries are in europe", return_tensors="pt")

generated = model.generate(input_ids=input_dict["input_ids"])
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])

# should give 54 => google says either 44 or 51
```