---
language: en
license: apache-2.0
datasets:
- wiki_dpr
thumbnail: https://huggingface.co/front/thumbnails/facebook.png
---
## RAG

This is a non-finetuned version of the RAG-Token model from the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/pdf/2005.11401.pdf)
by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al.

RAG consists of a *question encoder*, a *retriever*, and a *generator*. The retriever should be a `RagRetriever` instance. The *question encoder* can be any model that can be loaded with `AutoModel`, and the *generator* can be any model that can be loaded with `AutoModelForSeq2SeqLM`.
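To make the interplay of the three components concrete, here is a minimal, library-free sketch of a single RAG-Token decoding step as described in the paper: the retriever scores documents by inner product with the question embedding, and the generator's per-document next-token distributions are averaged with those document probabilities as weights. All embeddings, probabilities, and helper names (`softmax`, `dot`) are made-up toy values for illustration, not outputs of this model.

```python
import math


def softmax(scores):
    """Turn raw scores into a probability distribution."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


# Toy stand-in for the question encoder output (hypothetical values).
question_embedding = [1.0, 0.0]

# Toy stand-ins for two retrieved document embeddings (hypothetical values).
doc_embeddings = [[2.0, 0.0], [0.0, 2.0]]

# Retriever: score each document by inner product, then normalize.
doc_scores = [dot(question_embedding, d) for d in doc_embeddings]
doc_probs = softmax(doc_scores)

# Toy stand-in for the generator: one next-token distribution per document
# over a vocabulary of three tokens (hypothetical values).
token_probs_per_doc = [
    [0.7, 0.2, 0.1],
    [0.1, 0.1, 0.8],
]

# RAG-Token: marginalize over documents separately for each generated token.
vocab_size = len(token_probs_per_doc[0])
marginal = [
    sum(p_doc * probs[t] for p_doc, probs in zip(doc_probs, token_probs_per_doc))
    for t in range(vocab_size)
]

print(doc_probs)
print(marginal)
```

Because the first document matches the question better, its token distribution dominates the mixture. In the real model this marginalization is repeated at every decoding step.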

This model is a non-finetuned RAG-Token model and was created as follows:

```python
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration, AutoTokenizer

# Combine a pretrained DPR question encoder with a pretrained BART generator
model = RagTokenForGeneration.from_pretrained_question_encoder_generator("facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large")

question_encoder_tokenizer = AutoTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
generator_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")

# RagTokenizer wraps both tokenizers
tokenizer = RagTokenizer(question_encoder_tokenizer, generator_tokenizer)

# Default to the dummy dataset with the "exact" index
model.config.use_dummy_dataset = True
model.config.index_name = "exact"
retriever = RagRetriever(model.config, question_encoder_tokenizer, generator_tokenizer)

# Save model, tokenizer, and retriever to the current directory
model.save_pretrained("./")
tokenizer.save_pretrained("./")
retriever.save_pretrained("./")
```

Note that the model is *uncased*: all capital letters in the input are converted to lowercase.
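As a rough illustration of what *uncased* means in practice, the sketch below approximates the lowercasing step with plain `str.lower()`. The helper `normalize_question` is hypothetical and not part of `transformers`; the real tokenizer handles this internally and may apply further normalization.

```python
def normalize_question(question: str) -> str:
    """Approximate uncased preprocessing by lowercasing (illustrative only)."""
    return question.lower()


cased = "Who Holds the Record in 100m Freestyle"
uncased = "who holds the record in 100m freestyle"

# Both spellings collapse to the same text, so the model cannot tell them apart.
print(normalize_question(cased) == normalize_question(uncased))
```

Since capitalization carries no signal for this model, there is no need to restore proper casing in questions before tokenizing them.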

## Usage

*Note*: the model uses the *dummy* retriever by default. Better results are obtained with the full retriever,
which is enabled by setting `config.index_name="legacy"` and `config.use_dummy_dataset=False`.

The model can be fine-tuned as follows:

```python
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
retriever = RagRetriever.from_pretrained("facebook/rag-token-base")
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)

input_dict = tokenizer.prepare_seq2seq_batch("who holds the record in 100m freestyle", "michael phelps", return_tensors="pt")

outputs = model(input_dict["input_ids"], labels=input_dict["labels"])

loss = outputs.loss

# train on loss
```
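The note on the dummy versus full retriever can be sketched as follows. The helpers `retriever_overrides` and `load_retriever` are hypothetical names introduced for illustration. The sketch assumes that keyword arguments passed to `RagRetriever.from_pretrained` override the saved config values, and that the full index is a very large download.

```python
def retriever_overrides(use_full_index: bool) -> dict:
    """Return the config overrides for the dummy or the full retriever."""
    if use_full_index:
        # Settings named in the note above for the full retriever.
        return {"index_name": "legacy", "use_dummy_dataset": False}
    # Settings used when this checkpoint was created (dummy retriever).
    return {"index_name": "exact", "use_dummy_dataset": True}


def load_retriever(use_full_index: bool = False):
    """Load the retriever for facebook/rag-token-base with the chosen index."""
    # Imported lazily so the helper above works without transformers installed.
    from transformers import RagRetriever

    return RagRetriever.from_pretrained(
        "facebook/rag-token-base", **retriever_overrides(use_full_index)
    )


print(retriever_overrides(use_full_index=True))
```

Calling `load_retriever(use_full_index=True)` and passing the result as `retriever=` to `RagTokenForGeneration.from_pretrained` would then swap in the full index without changing the rest of the fine-tuning code.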