---
language: en
license: apache-2.0
datasets:
- wiki_dpr
---
## RAG

This is the RAG-Token Model of the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/pdf/2005.11401.pdf)
by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al.

The model is an *uncased* model, which means that capital letters are simply converted to lower-case letters.
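
As a toy illustration (not the model's actual tokenizer code), "uncased" means capitalization is discarded before tokenization, so differently-capitalized inputs map to the same text:

```python
# Toy illustration of "uncased" normalization: capitalization is discarded,
# so differently-capitalized inputs become identical before tokenization.
def normalize_uncased(text: str) -> str:
    return text.lower()

print(normalize_uncased("Who Holds The Record In 100m Freestyle?"))
# -> "who holds the record in 100m freestyle?"
```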

The model consists of a *question_encoder*, *retriever* and a *generator*. The retriever extracts relevant passages from the *wiki_dpr* `train` dataset, which is linked above.
The question_encoder is based on `facebook/dpr-question_encoder-single-nq-base` and the generator on `facebook/bart-large`, which were jointly fine-tuned
on the *wiki_dpr* QA dataset in an end-to-end fashion.
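
As described in the paper linked above, the RAG-Token variant marginalizes over the retrieved passages at every generation step. A minimal toy sketch of that per-token marginalization (made-up scores, not the real model's numbers):

```python
import math

# Toy setup: 2 retrieved docs with unnormalized retrieval scores, and each
# doc's next-token distribution over a 3-token vocabulary.
doc_scores = [1.0, 0.5]
p_token_given_doc = [
    [0.7, 0.2, 0.1],  # P(token | doc 0, prefix)
    [0.1, 0.6, 0.3],  # P(token | doc 1, prefix)
]

# Softmax over the retrieval scores gives P(doc | question).
z = sum(math.exp(s) for s in doc_scores)
p_doc = [math.exp(s) / z for s in doc_scores]

# RAG-Token: P(token) = sum_d P(doc = d) * P(token | d), recomputed per step.
p_token = [
    sum(p_doc[d] * p_token_given_doc[d][t] for d in range(len(p_doc)))
    for t in range(3)
]
print(p_token)  # a valid distribution: the three probabilities sum to 1
```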

## Usage:

**Note**: In the usage example below only the *dummy* retriever of *wiki_dpr* is used because the real retriever requires over 40 GB of RAM.
The model can generate answers to any question as follows:

```python
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

input_dict = tokenizer.prepare_seq2seq_batch("who holds the record in 100m freestyle", return_tensors="pt") 

generated = model.generate(input_ids=input_dict["input_ids"]) 
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0]) 

# should give michael phelps => sounds reasonable
```