test_modeling_rag.py 33.5 KB
Newer Older
Ola Piktus's avatar
Ola Piktus committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# coding=utf-8
# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import os
import shutil
import tempfile
import unittest
from unittest.mock import patch

import numpy as np

26
from transformers import BartTokenizer, T5Tokenizer
Ola Piktus's avatar
Ola Piktus committed
27
from transformers.file_utils import cached_property, is_datasets_available, is_faiss_available, is_torch_available
28
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
Ola Piktus's avatar
Ola Piktus committed
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES

from .test_modeling_bart import ModelTester as BartModelTester
from .test_modeling_dpr import DPRModelTester
from .test_modeling_t5 import T5ModelTester


TOLERANCE = 1e-3

T5_SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")

if is_torch_available() and is_datasets_available() and is_faiss_available():
    import torch
    from datasets import Dataset

    import faiss
    from transformers import (
        AutoConfig,
        AutoModel,
        AutoModelForSeq2SeqLM,
        RagConfig,
        RagModel,
        RagRetriever,
        RagSequenceForGeneration,
        RagTokenForGeneration,
56
        RagTokenizer,
Ola Piktus's avatar
Ola Piktus committed
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
    )
    from transformers.modeling_outputs import BaseModelOutput


def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
    """If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
    if a is None and b is None:
        return True
    try:
        if torch.allclose(a, b, atol=atol):
            return True
        raise
    except Exception:
        msg = "{} != {}".format(a, b)
        if prefix:
            msg = prefix + ": " + msg
        raise AssertionError(msg)


def require_retrieval(test_case):
    """
    Decorator marking a test that requires the set of dependencies necessary to perform retrieval with
    :class:`~transformers.RagRetriever`.

    These tests are skipped when PyTorch, ``datasets`` or ``faiss`` is not installed.

    """
    if not (is_torch_available() and is_datasets_available() and is_faiss_available()):
        # Name every dependency the guard checks so the skip reason is not misleading.
        test_case = unittest.skip("test requires PyTorch, datasets and faiss")(test_case)
    return test_case


@require_torch
@require_retrieval
91
@require_sentencepiece
Ola Piktus's avatar
Ola Piktus committed
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
class RagTestMixin:
    """
    Shared fixtures and checks for the RAG model tests.

    Test cases combine this mixin with :class:`unittest.TestCase` and expose a ``config_and_inputs``
    attribute: a dict with the keys ``config``, ``input_ids``, ``attention_mask``, ``decoder_input_ids`` and
    ``decoder_attention_mask``, which each ``test_*`` method unpacks into the matching ``check_*`` method.
    """

    all_model_classes = (
        (RagModel, RagTokenForGeneration, RagSequenceForGeneration)
        if is_torch_available() and is_datasets_available() and is_faiss_available()
        else ()
    )

    # Sizes used by the dummy retrieval dataset and by the shape assertions below.
    retrieval_vector_size = 32
    n_docs = 2
    max_combined_length = 16

    def setUp(self):
        """Create a temporary directory containing small DPR, BART and T5 tokenizer files."""
        self.tmpdirname = tempfile.mkdtemp()

        # DPR tok
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        dpr_tokenizer_path = os.path.join(self.tmpdirname, "dpr_tokenizer")
        os.makedirs(dpr_tokenizer_path, exist_ok=True)
        self.vocab_file = os.path.join(dpr_tokenizer_path, DPR_VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

        # BART tok
        vocab = [
            "l",
            "o",
            "w",
            "e",
            "r",
            "s",
            "t",
            "i",
            "d",
            "n",
            "\u0120",
            "\u0120l",
            "\u0120n",
            "\u0120lo",
            "\u0120low",
            "er",
            "\u0120lowest",
            "\u0120newer",
            "\u0120wider",
            "<unk>",
        ]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
        self.special_tokens_map = {"unk_token": "<unk>"}

        bart_tokenizer_path = os.path.join(self.tmpdirname, "bart_tokenizer")
        os.makedirs(bart_tokenizer_path, exist_ok=True)
        # NOTE: this overwrites the DPR path stored in ``self.vocab_file`` above.
        self.vocab_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(bart_tokenizer_path, BART_VOCAB_FILES_NAMES["merges_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")
        with open(self.merges_file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(merges))

        # T5 tok: built from the sentencepiece fixture and saved next to the other tokenizers
        t5_tokenizer = T5Tokenizer(T5_SAMPLE_VOCAB)
        t5_tokenizer_path = os.path.join(self.tmpdirname, "t5_tokenizer")
        t5_tokenizer.save_pretrained(t5_tokenizer_path)

    @cached_property
    def dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
        """DPR question-encoder tokenizer loaded from the files written by ``setUp``."""
        return DPRQuestionEncoderTokenizer.from_pretrained(os.path.join(self.tmpdirname, "dpr_tokenizer"))

    @cached_property
    def bart_tokenizer(self) -> BartTokenizer:
        """BART tokenizer loaded from the files written by ``setUp``."""
        return BartTokenizer.from_pretrained(os.path.join(self.tmpdirname, "bart_tokenizer"))

    @cached_property
    def t5_tokenizer(self) -> T5Tokenizer:
        """T5 tokenizer loaded from the files written by ``setUp``."""
        return T5Tokenizer.from_pretrained(os.path.join(self.tmpdirname, "t5_tokenizer"))

    def tearDown(self):
        """Remove the temporary directory created by ``setUp``."""
        shutil.rmtree(self.tmpdirname)

    def get_retriever(self, config):
        """
        Build a :class:`~transformers.RagRetriever` backed by a two-row in-memory dataset.

        The dataset gets a flat inner-product FAISS index on its ``embeddings`` column, and
        ``transformers.retrieval_rag.load_dataset`` is patched to return it while the retriever is
        constructed. The generator tokenizer is chosen from ``config.generator.model_type``.
        """
        dataset = Dataset.from_dict(
            {
                "id": ["0", "1"],
                "text": ["foo", "bar"],
                "title": ["Foo", "Bar"],
                "embeddings": [np.ones(self.retrieval_vector_size), 2 * np.ones(self.retrieval_vector_size)],
            }
        )
        dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)
        tokenizer = self.bart_tokenizer if config.generator.model_type == "bart" else self.t5_tokenizer
        with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
            mock_load_dataset.return_value = dataset
            retriever = RagRetriever(
                config,
                question_encoder_tokenizer=self.dpr_tokenizer,
                generator_tokenizer=tokenizer,
            )
        return retriever

    def _check_output_shapes(self, outputs, config, input_ids, decoder_input_ids):
        """Assert the shapes of ``logits``, ``generator_enc_last_hidden_state`` and ``doc_scores``."""
        # logits
        self.assertEqual(
            outputs.logits.shape,
            (self.n_docs * decoder_input_ids.shape[0], decoder_input_ids.shape[1], config.generator.vocab_size),
        )
        # generator encoder last hidden states
        self.assertEqual(
            outputs.generator_enc_last_hidden_state.shape,
            (self.n_docs * decoder_input_ids.shape[0], self.max_combined_length, config.generator.hidden_size),
        )
        # doc scores
        self.assertEqual(outputs.doc_scores.shape, (input_ids.shape[0], self.n_docs))

    def check_model_with_retriever(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
    ):
        """Run a forward pass of every model class with an attached retriever and check output shapes."""
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        for model_class in self.all_model_classes:
            model = model_class(config, retriever=self.get_retriever(config)).to(torch_device)
            model.eval()

            self.assertTrue(model.config.is_encoder_decoder)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            self._check_output_shapes(outputs, config, input_ids, decoder_input_ids)

    def check_model_generate(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
    ):
        """Call ``generate`` with two beams on the generation model classes and check a result is returned."""
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        # ``[1:]`` skips ``RagModel``, the first entry of ``all_model_classes``.
        for model_class in self.all_model_classes[1:]:
            model = model_class(config, retriever=self.get_retriever(config)).to(torch_device)
            model.eval()

            self.assertTrue(model.config.is_encoder_decoder)

            outputs = model.generate(
                input_ids=input_ids,
                num_beams=2,
                num_return_sequences=2,
                decoder_start_token_id=config.generator.eos_token_id,
            )

            self.assertIsNotNone(outputs)

    def check_model_without_retriever(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
    ):
        """
        Run every model class without an attached retriever.

        Retrieval is done by hand: the question is encoded, the retriever is called directly, and the
        resulting context inputs and document scores are passed to the model.
        """
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        retriever = self.get_retriever(config)

        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)
            model.eval()
            self.assertTrue(model.config.is_encoder_decoder)

            question_hidden_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0]

            # the retriever is given the question embeddings as a float32 numpy array
            out = retriever(
                input_ids,
                question_hidden_states.cpu().detach().to(torch.float32).numpy(),
                prefix=config.generator.prefix,
                return_tensors="pt",
            )

            context_input_ids, context_attention_mask, retrieved_doc_embeds = (
                out["context_input_ids"],
                out["context_attention_mask"],
                out["retrieved_doc_embeds"],
            )

            # cast retrieved tensors to the dtype/device of the tensors they are combined with
            retrieved_doc_embeds = retrieved_doc_embeds.to(question_hidden_states)
            context_input_ids = context_input_ids.to(input_ids)
            context_attention_mask = context_attention_mask.to(input_ids)

            # compute doc_scores: batched product of each question vector with its retrieved doc embeddings
            doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)).squeeze(
                1
            )

            outputs = model(
                context_input_ids=context_input_ids,
                context_attention_mask=context_attention_mask,
                doc_scores=doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            self._check_output_shapes(outputs, config, input_ids, decoder_input_ids)

    def check_model_with_encoder_outputs(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
    ):
        """Re-run every model from the encoder outputs and doc scores of a first full forward pass."""
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        for model_class in self.all_model_classes:
            model = model_class(config, retriever=self.get_retriever(config)).to(torch_device)
            model.eval()

            self.assertTrue(model.config.is_encoder_decoder)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            encoder_outputs = BaseModelOutput(outputs.generator_enc_last_hidden_state)

            # run only generator
            outputs = model(
                encoder_outputs=encoder_outputs,
                doc_scores=outputs.doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            self._check_output_shapes(outputs, config, input_ids, decoder_input_ids)

    def test_model_with_retriever(self):
        inputs_dict = self.config_and_inputs
        self.check_model_with_retriever(**inputs_dict)

    def test_model_without_retriever(self):
        inputs_dict = self.config_and_inputs
        self.check_model_without_retriever(**inputs_dict)

    def test_model_with_encoder_outputs(self):
        inputs_dict = self.config_and_inputs
        self.check_model_with_encoder_outputs(**inputs_dict)

    def test_model_generate(self):
        inputs_dict = self.config_and_inputs
        self.check_model_generate(**inputs_dict)


@require_torch
@require_retrieval
class RagDPRBartTest(RagTestMixin, unittest.TestCase):
    """Runs the shared RAG checks with a DPR question encoder and a BART generator."""

    @cached_property
    def config_and_inputs(self):
        """Return the RAG config and model inputs as the keyword dict used by the ``check_*`` methods."""
        # Keep the DPR tester ahead of the BART tester: same call order as before.
        dpr_config, question_ids, _, question_mask, _, _, _ = DPRModelTester(self).prepare_config_and_inputs()
        bart_config, bart_inputs = BartModelTester(self).prepare_config_and_inputs_for_common()

        target_ids = bart_inputs["input_ids"]
        target_mask = bart_inputs["attention_mask"]

        rag_config = RagConfig.from_question_encoder_generator_configs(
            dpr_config,
            bart_config,
            n_docs=self.n_docs,
            retrieval_vector_size=self.retrieval_vector_size,
            max_combined_length=self.max_combined_length,
            use_cache=False,
        )

        return dict(
            config=rag_config,
            input_ids=question_ids,
            attention_mask=question_mask,
            decoder_input_ids=target_ids,
            decoder_attention_mask=target_mask,
        )


@require_torch
@require_retrieval
class RagDPRT5Test(RagTestMixin, unittest.TestCase):
    """Runs the shared RAG checks with a DPR question encoder and a T5 generator."""

    @cached_property
    def config_and_inputs(self):
        """Return the RAG config and model inputs as the keyword dict used by the ``check_*`` methods."""
        # Keep the DPR tester ahead of the T5 tester: same call order as before.
        dpr_config, question_ids, _, question_mask, _, _, _ = DPRModelTester(self).prepare_config_and_inputs()
        t5_tester = T5ModelTester(self, vocab_size=1100, n_positions=30)
        t5_config, _, target_ids, _, target_mask, _ = t5_tester.prepare_config_and_inputs()

        rag_config = RagConfig.from_question_encoder_generator_configs(
            dpr_config,
            t5_config,
            n_docs=self.n_docs,
            retrieval_vector_size=self.retrieval_vector_size,
            max_combined_length=self.max_combined_length,
            use_cache=False,
        )

        return dict(
            config=rag_config,
            input_ids=question_ids,
            attention_mask=question_mask,
            decoder_input_ids=target_ids,
            decoder_attention_mask=target_mask,
        )


@require_torch
@require_retrieval
441
442
@require_sentencepiece
@require_tokenizers
Ola Piktus's avatar
Ola Piktus committed
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
class RagModelIntegrationTests(unittest.TestCase):
    @cached_property
    def sequence_model(self):
        """RAG-Sequence model built from the DPR and BART checkpoints, on ``torch_device`` in eval mode."""
        model = RagSequenceForGeneration.from_pretrained_question_encoder_generator(
            "facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large-cnn"
        )
        model = model.to(torch_device)
        return model.eval()

    @cached_property
    def token_model(self):
        """RAG-Token model built from the DPR and BART checkpoints, on ``torch_device`` in eval mode."""
        model = RagTokenForGeneration.from_pretrained_question_encoder_generator(
            "facebook/dpr-question_encoder-single-nq-base", "facebook/bart-large-cnn"
        )
        model = model.to(torch_device)
        return model.eval()

    def get_rag_config(self):
        """Combine the pretrained DPR and BART configs into the ``RagConfig`` used by these tests."""
        encoder_cfg = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
        generator_cfg = AutoConfig.from_pretrained("facebook/bart-large-cnn")
        # Extra settings passed on top of the two sub-configs.
        rag_settings = dict(
            bos_token_id=0,
            decoder_start_token_id=2,
            eos_token_id=2,
            is_encoder_decoder=True,
            pad_token_id=1,
            vocab_size=50264,
            title_sep=" / ",
            doc_sep=" // ",
            n_docs=5,
            max_combined_length=300,
            dataset="wiki_dpr",
            dataset_split="train",
            index_name="exact",
            index_path=None,
            use_dummy_dataset=True,
            retrieval_vector_size=768,
            retrieval_batch_size=8,
        )
        return RagConfig.from_question_encoder_generator_configs(encoder_cfg, generator_cfg, **rag_settings)

    @slow
    def test_rag_sequence_inference(self):
        """Check logits shape, document scores and loss of one RAG-Sequence forward pass with labels."""
        rag_config = self.get_rag_config()
        rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
        rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        rag_retriever = RagRetriever(
            rag_config,
            question_encoder_tokenizer=rag_question_encoder_tokenizer,
            generator_tokenizer=rag_decoder_tokenizer,
        )

        rag_sequence = self.sequence_model
        rag_sequence.set_retriever(rag_retriever)

        input_ids = rag_question_encoder_tokenizer(
            "who sings does he love me with reba", return_tensors="pt"
        ).input_ids
        decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="pt").input_ids

        input_ids = input_ids.to(torch_device)
        decoder_input_ids = decoder_input_ids.to(torch_device)

        with torch.no_grad():
            output = rag_sequence(
                input_ids,
                labels=decoder_input_ids,
            )

        expected_shape = torch.Size([5, 5, 50264])
        self.assertEqual(output.logits.shape, expected_shape)

        expected_doc_scores = torch.tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]]).to(torch_device)
        _assert_tensors_equal(expected_doc_scores, output.doc_scores, atol=TOLERANCE)

        expected_loss = torch.tensor([36.7368]).to(torch_device)
        _assert_tensors_equal(expected_loss, output.loss, atol=TOLERANCE)

    @slow
    def test_rag_token_inference(self):
        """Check logits shape, document scores and loss of one RAG-Token forward pass with labels."""
        rag_config = self.get_rag_config()
        rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
        rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        rag_retriever = RagRetriever(
            rag_config,
            question_encoder_tokenizer=rag_question_encoder_tokenizer,
            generator_tokenizer=rag_decoder_tokenizer,
        )

        rag_token = self.token_model
        rag_token.set_retriever(rag_retriever)

        input_ids = rag_question_encoder_tokenizer(
            "who sings does he love me with reba", return_tensors="pt"
        ).input_ids
        decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="pt").input_ids

        input_ids = input_ids.to(torch_device)
        decoder_input_ids = decoder_input_ids.to(torch_device)

        with torch.no_grad():
            output = rag_token(
                input_ids,
                labels=decoder_input_ids,
            )

        expected_shape = torch.Size([5, 5, 50264])
        self.assertEqual(output.logits.shape, expected_shape)

        expected_doc_scores = torch.tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]]).to(torch_device)
        _assert_tensors_equal(expected_doc_scores, output.doc_scores, atol=TOLERANCE)

        expected_loss = torch.tensor([36.3557]).to(torch_device)
        _assert_tensors_equal(expected_loss, output.loss, atol=TOLERANCE)

    @slow
    def test_rag_token_generate_beam(self):
        """Generate two beams with RAG-Token and compare the decoded texts with recorded outputs."""
        rag_config = self.get_rag_config()
        rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
        rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        rag_retriever = RagRetriever(
            rag_config,
            question_encoder_tokenizer=rag_question_encoder_tokenizer,
            generator_tokenizer=rag_decoder_tokenizer,
        )

        rag_token = self.token_model
        rag_token.set_retriever(rag_retriever)

        input_ids = rag_question_encoder_tokenizer(
            "who sings does he love me with reba", return_tensors="pt"
        ).input_ids

        input_ids = input_ids.to(torch_device)

        output_ids = rag_token.generate(
            input_ids,
            decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
            num_beams=2,
            num_return_sequences=2,
        )
        # sequence generate test
        output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
        output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)

        # Expected outputs as given by model at integration time.
        EXPECTED_OUTPUT_TEXT_1 = "\"She's My Kind of Girl"
        EXPECTED_OUTPUT_TEXT_2 = "\"She's My Kind of Love"

        self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
        self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)

    @slow
    def test_rag_sequence_generate_beam(self):
        """Generate two beams with RAG-Sequence and compare the decoded texts with recorded outputs."""
        rag_config = self.get_rag_config()
        rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
        rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        rag_retriever = RagRetriever(
            rag_config,
            question_encoder_tokenizer=rag_question_encoder_tokenizer,
            generator_tokenizer=rag_decoder_tokenizer,
        )

        rag_sequence = self.sequence_model
        rag_sequence.set_retriever(rag_retriever)

        input_ids = rag_question_encoder_tokenizer(
            "who sings does he love me with reba", return_tensors="pt"
        ).input_ids

        input_ids = input_ids.to(torch_device)

        output_ids = rag_sequence.generate(
            input_ids,
            decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id,
            num_beams=2,
            num_return_sequences=2,
        )
        # sequence generate test
        output_text_1 = rag_decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
        output_text_2 = rag_decoder_tokenizer.decode(output_ids[1], skip_special_tokens=True)

        # Expected outputs as given by model at integration time.
        EXPECTED_OUTPUT_TEXT_1 = """\"She's My Kind of Girl\" was released through Epic Records in Japan in March 1972, giving the duo a Top 10 hit. Two more singles were released in Japan, \"En Carousel\" and \"Love Has Its Ways\" Ulvaeus and Andersson persevered with their songwriting and experimented with new sounds and vocal arrangements."""
        EXPECTED_OUTPUT_TEXT_2 = """In September 2018, Björn Ulvaeus revealed that the two new songs, \"I Still Have Faith In You\" and \"Don't Shut Me Down\", would be released no earlier than March 2019. The two new tracks will feature in a TV special set to air later in the year."""

        self.assertEqual(output_text_1, EXPECTED_OUTPUT_TEXT_1)
        self.assertEqual(output_text_2, EXPECTED_OUTPUT_TEXT_2)
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663

    @property
    def test_data_questions(self):
        """Questions tokenized as one batch by the batched generation tests."""
        questions = [
            "who got the first nobel prize in physics",
            "when is the next deadpool movie being released",
            "which mode is used for short wave broadcast service",
            "who is the owner of reading football club",
            "when is the next scandal episode coming out",
            "when is the last time the philadelphia won the superbowl",
            "what is the most current adobe flash player version",
            "how many episodes are there in dragon ball z",
            "what is the first step in the evolution of the eye",
            "where is gall bladder situated in human body",
            "what is the main mineral in lithium batteries",
            "who is the president of usa right now",
            "where do the greasers live in the outsiders",
            "panda is a national animal of which country",
            "what is the name of manchester united stadium",
        ]
        return questions
Ola Piktus's avatar
Ola Piktus committed
664
665
666

    @slow
    def test_rag_sequence_generate_batch(self):
        """Generate answers for the whole question batch with the ``facebook/rag-sequence-nq`` checkpoint."""
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
        retriever = RagRetriever.from_pretrained(
            "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
        )
        # NOTE(review): the rag-sequence checkpoint is loaded with RagTokenForGeneration although the test
        # and the variable are named for RAG-Sequence. Confirm whether RagSequenceForGeneration was intended;
        # EXPECTED_OUTPUTS may need regenerating if the class changes.
        rag_sequence = RagTokenForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
            torch_device
        )

        input_dict = tokenizer(
            self.test_data_questions,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

        input_ids = input_dict.input_ids.to(torch_device)
        attention_mask = input_dict.attention_mask.to(torch_device)

        output_ids = rag_sequence.generate(
            input_ids,
            attention_mask=attention_mask,
        )

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        EXPECTED_OUTPUTS = [
            " albert einstein",
            " june 22, 2018",
            " amplitude modulation",
            " tim besley ( chairman )",
            " june 20, 2018",
            " 1980",
            " 7.0",
            " 8",
            " reticular formation",
            " walls of the abdomen",
            " spodumene",
            " obama",
            " grainger's compound",
            " japan",
            " old trafford stadium",
        ]
        self.assertListEqual(outputs, EXPECTED_OUTPUTS)
Ola Piktus's avatar
Ola Piktus committed
710
711

    @slow
    def test_rag_token_generate_batch(self):
        """Generate answers for the whole question batch with the ``facebook/rag-token-nq`` checkpoint."""
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
        retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
        rag_token = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever).to(
            torch_device
        )

        input_dict = tokenizer(
            self.test_data_questions,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

        input_ids = input_dict.input_ids.to(torch_device)
        attention_mask = input_dict.attention_mask.to(torch_device)

        output_ids = rag_token.generate(
            input_ids,
            attention_mask=attention_mask,
        )

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        EXPECTED_OUTPUTS = [
            " albert einstein",
            " september 22, 2017",
            " amplitude modulation",
            " stefan persson",
            " april 20, 2018",
            " the 1970s",
            " 7.1. 2",
            " 13",
            " step by step",
            " stomach",
            " spodumene",
            " obama",
            " northern new jersey",
            " india",
            " united stadium",
        ]
        self.assertListEqual(outputs, EXPECTED_OUTPUTS)
Ola Piktus's avatar
Ola Piktus committed
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900


@require_torch
@require_retrieval
class RagModelSaveLoadTests(unittest.TestCase):
    """Check that RAG models give the same loss after a save/reload round trip as a freshly assembled model."""

    def get_rag_config(self):
        """Return the `RagConfig` (DPR question encoder + BART generator sub-configs) used by every test here."""
        question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
        generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn")
        return RagConfig.from_question_encoder_generator_configs(
            question_encoder_config,
            generator_config,
            bos_token_id=0,
            decoder_start_token_id=2,
            eos_token_id=2,
            is_encoder_decoder=True,
            pad_token_id=1,
            vocab_size=50264,
            title_sep=" / ",
            doc_sep=" // ",
            n_docs=5,
            max_combined_length=300,
            dataset="wiki_dpr",
            dataset_split="train",
            index_name="exact",
            index_path=None,
            use_dummy_dataset=True,
            retrieval_vector_size=768,
            retrieval_batch_size=8,
        )

    def _check_save_load_roundtrip(self, model_class):
        """Compare the loss of a saved-and-reloaded `model_class` with the loss of a directly assembled one.

        Args:
            model_class: the RAG generation class under test (`RagSequenceForGeneration` or
                `RagTokenForGeneration`).
        """
        rag_config = self.get_rag_config()
        rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
        rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        rag_retriever = RagRetriever(
            rag_config,
            question_encoder_tokenizer=rag_question_encoder_tokenizer,
            generator_tokenizer=rag_decoder_tokenizer,
        )

        input_ids = rag_question_encoder_tokenizer(
            "who sings does he love me with reba", return_tensors="pt"
        ).input_ids
        decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="pt").input_ids

        input_ids = input_ids.to(torch_device)
        decoder_input_ids = decoder_input_ids.to(torch_device)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model = model_class.from_pretrained_question_encoder_generator(
                "facebook/dpr-question_encoder-single-nq-base",
                "facebook/bart-large-cnn",
                retriever=rag_retriever,
                config=rag_config,
            ).to(torch_device)
            model.save_pretrained(tmp_dirname)
            # Drop the in-memory model so that only the checkpoint written above can supply the weights below.
            del model

            # `from_pretrained` is a classmethod that returns a new model; its result must be bound,
            # otherwise the reloaded weights are thrown away and never exercised by the forward pass.
            model = model_class.from_pretrained(tmp_dirname, retriever=rag_retriever).to(torch_device)

            with torch.no_grad():
                output = model(
                    input_ids,
                    labels=decoder_input_ids,
                )

            loss_pretrained = output.loss
            del model

        # Reference: the same architecture assembled by hand from the two pretrained sub-models.
        question_encoder = AutoModel.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
        generator = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
        model = model_class(
            config=rag_config, question_encoder=question_encoder, generator=generator, retriever=rag_retriever
        )
        model.to(torch_device)

        with torch.no_grad():
            output = model(
                input_ids,
                labels=decoder_input_ids,
            )

        loss_init = output.loss

        self.assertAlmostEqual(loss_pretrained.item(), loss_init.item(), places=4)

    @slow
    def test_rag_sequence_from_pretrained(self):
        """Save/reload round trip for `RagSequenceForGeneration`."""
        self._check_save_load_roundtrip(RagSequenceForGeneration)

    @slow
    def test_rag_token_from_pretrained(self):
        """Save/reload round trip for `RagTokenForGeneration`."""
        self._check_save_load_roundtrip(RagTokenForGeneration)