---
license: apache-2.0
thumbnail: https://huggingface.co/front/thumbnails/facebook.png
---
## RAG

This is a non-finetuned version of the RAG-Sequence model from the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks](https://arxiv.org/pdf/2005.11401.pdf) by Patrick Lewis, Ethan Perez, Aleksandra Piktus, et al.

RAG consists of a *question encoder*, a *retriever*, and a *generator*. The retriever should be a `RagRetriever` instance. The *question encoder* can be any model that can be loaded with `AutoModel`, and the *generator* can be any model that can be loaded with `AutoModelForSeq2SeqLM`.

This model is a non-finetuned RAG-Sequence model and was created as follows:

```python
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, AutoTokenizer

# combine a DPR question encoder with a BART generator
model = RagSequenceForGeneration.from_pretrained_question_encoder_generator(
    "facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large"
)

question_encoder_tokenizer = AutoTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
generator_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")

tokenizer = RagTokenizer(question_encoder_tokenizer, generator_tokenizer)

# use the dummy dataset with an exact index so no full index download is required
model.config.use_dummy_dataset = True
model.config.index_name = "exact"
retriever = RagRetriever(model.config, question_encoder_tokenizer, generator_tokenizer)

model.save_pretrained("./")
tokenizer.save_pretrained("./")
retriever.save_pretrained("./")
```

Note that the model is *uncased*, meaning all capitalized input letters are converted to lower-case.
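Because the checkpoint is uncased, mixed-case and lower-cased queries are treated identically. A trivial sketch of the normalization (`normalize_query` is a hypothetical helper; the tokenizer performs this lower-casing internally, so calling it yourself is optional):

```python
# hypothetical helper illustrating the uncased behavior
def normalize_query(query: str) -> str:
    return query.lower().strip()

print(normalize_query("Who Holds The Record In 100m Freestyle?"))
# → "who holds the record in 100m freestyle?"
```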

## Usage

*Note*: the model uses the *dummy* retriever by default. Better results are obtained by using the full retriever,
i.e. by setting `config.index_name="legacy"` and `config.use_dummy_dataset=False`.
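For illustration, these settings can also be passed directly when loading the retriever (a sketch; downloading the full legacy index requires substantial disk space and bandwidth):

```python
from transformers import RagRetriever

# load the full (non-dummy) retriever with the legacy index
retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-base",
    index_name="legacy",
    use_dummy_dataset=False,
)
```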
The model can be fine-tuned as follows:

```python
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration

tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base")
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-base")
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-base", retriever=retriever)

input_dict = tokenizer.prepare_seq2seq_batch("who holds the record in 100m freestyle", "michael phelps", return_tensors="pt") 

outputs = model(input_dict["input_ids"], labels=input_dict["labels"])

loss = outputs.loss

# train on loss
```
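The `# train on loss` step above can be sketched generically. In the following minimal, hypothetical stand-in, a toy `torch.nn.Linear` replaces the RAG model so the snippet runs without downloading any checkpoint; with the real model, `loss` would come from `model(input_dict["input_ids"], labels=input_dict["labels"]).loss`:

```python
import torch

# toy stand-in for the RAG model (hypothetical, for illustration only)
model = torch.nn.Linear(4, 2)
inputs = torch.randn(8, 4)
targets = torch.randint(0, 2, (8,))

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

optimizer.zero_grad()
loss = torch.nn.functional.cross_entropy(model(inputs), targets)
loss.backward()   # compute gradients of the loss
optimizer.step()  # update the parameters
```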