---
language: en
license: apache-2.0
datasets:
- wiki_dpr
thumbnail: https://huggingface.co/front/thumbnails/facebook.png
---
## RAG

This is the RAG-Token Model of the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/pdf/2005.11401.pdf)
by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al.

The model is an *uncased* model, which means that capital letters are simply converted to lower-case letters.
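
As a rough illustration of what *uncased* means in practice, the sketch below lower-cases two differently capitalized questions with plain Python. It assumes lower-casing is the only relevant normalization step, which may not match everything the actual tokenizer does; `normalize_question` is a hypothetical helper and not part of `transformers`.

```python
def normalize_question(question: str) -> str:
    """Hypothetical helper: mimic the lower-casing an uncased model applies."""
    return question.lower()


a = normalize_question("Who holds the record in 100m freestyle")
b = normalize_question("who holds the record in 100M Freestyle")

# Both spellings collapse to the same string, so capitalization carries no signal.
print(a == b)
```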

The model consists of a *question_encoder*, a *retriever* and a *generator*. The retriever extracts relevant passages from the `train` split of the *wiki_dpr* dataset, which is linked above.
The question_encoder and generator are based on `facebook/dpr-question_encoder-single-nq-base` and `facebook/bart-large`, respectively, which were jointly finetuned
on the *wiki_dpr* QA dataset in an end-to-end fashion.
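
To make the interplay of the three components more concrete, here is a minimal toy sketch, assuming inner-product retrieval and the per-token marginalization over retrieved passages described for RAG-Token in the paper. All vectors and probabilities are made up, and the function names (`retrieve`, `marginalize`) are illustrative only; they are not the `transformers` API.

```python
import math


def softmax(scores):
    """Turn raw scores into probabilities that sum to 1."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def retrieve(question_vec, passage_vecs, k):
    """Rank passages by inner product with the question and keep the top k."""
    scores = [sum(q * p for q, p in zip(question_vec, vec)) for vec in passage_vecs]
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    return top, softmax([scores[i] for i in top])


def marginalize(doc_probs, token_dists):
    """Weight each passage's next-token distribution by the passage
    probability and sum over passages (RAG-Token style mixing)."""
    vocab_size = len(token_dists[0])
    return [
        sum(p * dist[t] for p, dist in zip(doc_probs, token_dists))
        for t in range(vocab_size)
    ]


# Made-up embeddings: one question and three passages.
question_vec = [1.0, 0.0]
passage_vecs = [[2.0, 0.0], [0.0, 3.0], [1.0, 1.0]]

top_ids, doc_probs = retrieve(question_vec, passage_vecs, k=2)

# Made-up next-token distributions over a 3-token vocabulary, one per passage.
generator_dists = {0: [0.7, 0.2, 0.1], 1: [0.1, 0.1, 0.8], 2: [0.2, 0.6, 0.2]}
next_token_probs = marginalize(doc_probs, [generator_dists[i] for i in top_ids])

print(top_ids)
print(next_token_probs)
```

In the real model the question embedding would come from the DPR question encoder and the per-passage distributions from the BART generator; the toy only shows how the retrieval scores weight the generator's predictions.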

## Usage:

**Note**: The usage example below uses only the *dummy* retriever of *wiki_dpr*, because the complete *legacy* index requires over 75 GB of RAM.
The model can generate answers to any factoid question as follows:

```python
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

input_dict = tokenizer.prepare_seq2seq_batch("who holds the record in 100m freestyle", return_tensors="pt") 

generated = model.generate(input_ids=input_dict["input_ids"]) 
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0]) 

# should give michael phelps => sounds reasonable
```
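
The example above relies on the dummy retriever. On a machine with enough RAM, switching to the complete index might look like the sketch below. It assumes that `index_name="legacy"` together with `use_dummy_dataset=False` selects the complete index mentioned in the note; `retriever_kwargs` is a hypothetical helper, and the actual `from_pretrained` call is left as a comment because it downloads the model and the index.

```python
def retriever_kwargs(use_full_index: bool) -> dict:
    """Hypothetical helper: pick keyword arguments for RagRetriever.from_pretrained."""
    if use_full_index:
        # Assumption: the complete legacy index, which needs over 75 GB of RAM.
        return {"index_name": "legacy", "use_dummy_dataset": False}
    # Same settings as the usage example: the small dummy index.
    return {"index_name": "exact", "use_dummy_dataset": True}


print(retriever_kwargs(use_full_index=False))

# Assumed usage, not executed here:
# from transformers import RagRetriever
# retriever = RagRetriever.from_pretrained(
#     "facebook/rag-token-nq", **retriever_kwargs(use_full_index=True)
# )
```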