# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device

from .test_modeling_common import ModelTesterMixin


# Torch-dependent imports are guarded so this module can be collected even
# when torch is not installed; the @require_torch decorators below skip the
# tests in that case.
if is_torch_available():
    import torch

    from transformers import (
        AutoModelForSeq2SeqLM,
        AutoTokenizer,
        BatchEncoding,
        MBartConfig,
        MBartForConditionalGeneration,
    )


# Language-code token ids from the mBART-large vocabulary; EN_CODE appears as
# the final element of `expected_src_tokens` below (presumably the ids of the
# en_XX / ro_RO language tokens — see the tokenizer's lang_code_to_id usage).
EN_CODE = 250004
RO_CODE = 250020


@require_torch
class ModelTester:
    """Provides a tiny MBartConfig for the selective common tests below."""

    def __init__(self, parent):
        # `parent` is accepted for signature compatibility with other model
        # testers but is not stored; only the config is needed here.
        tiny_model_kwargs = {
            "vocab_size": 99,
            "d_model": 24,
            "encoder_layers": 2,
            "decoder_layers": 2,
            "encoder_attention_heads": 2,
            "decoder_attention_heads": 2,
            "encoder_ffn_dim": 32,
            "decoder_ffn_dim": 32,
            "max_position_embeddings": 48,
            "add_final_layer_norm": True,
        }
        self.config = MBartConfig(**tiny_model_kwargs)

    def prepare_config_and_inputs_for_common(self):
        """Return the tiny config and an empty inputs dict."""
        return self.config, {}


@require_torch
class SelectiveCommonTest(unittest.TestCase):
    """Runs a hand-picked subset of the common model tests against MBart."""

    # Falls back to an empty tuple when torch is unavailable so class-level
    # evaluation never fails at import time.
    all_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()

    # Borrow exactly one test from the common mixin instead of inheriting the
    # whole ModelTesterMixin suite.
    test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save

    def setUp(self):
        # The borrowed mixin test reads `self.model_tester`, which must expose
        # prepare_config_and_inputs_for_common().
        self.model_tester = ModelTester(self)


70
@require_torch
@require_sentencepiece
@require_tokenizers
class AbstractSeq2SeqIntegrationTest(unittest.TestCase):
    """Base class for slow seq2seq integration tests against a real hub checkpoint."""

    # Longer string-compare tracebacks for easier debugging of failures.
    maxDiff = 1000
    # Subclasses must override this with a hub model identifier.
    checkpoint_name = None

    @classmethod
    def setUpClass(cls):
        # The slow (sentencepiece) tokenizer is used explicitly (use_fast=False).
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.checkpoint_name, use_fast=False)
        return cls

    @cached_property
    def model(self):
        """Lazily load the checkpoint; half precision on CUDA to save memory."""
        loaded = AutoModelForSeq2SeqLM.from_pretrained(self.checkpoint_name)
        loaded = loaded.to(torch_device)
        if "cuda" in torch_device:
            loaded = loaded.half()
        return loaded


@require_torch
@require_sentencepiece
@require_tokenizers
class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
    """Slow integration tests for the finetuned facebook/mbart-large-en-ro checkpoint."""

    checkpoint_name = "facebook/mbart-large-en-ro"
    src_text = [
        " UN Chief Says There Is No Military Solution in Syria",
        """ Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that "there is no military solution" to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.""",
    ]
    # NOTE(review): the reference translations were mojibake in the checked-in
    # file (UTF-8 Romanian mis-decoded); restored to proper Romanian here.
    tgt_text = [
        "Şeful ONU declară că nu există o soluţie militară în Siria",
        'Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu vor face decât să înrăutăţească violenţa şi mizeria pentru milioane de oameni.',
    ]
    # Expected tokenization of src_text[0]; the language-code token (EN_CODE)
    # is the final element of the sequence.
    expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, EN_CODE]

    @slow
    def test_enro_generate_one(self):
        """Single-sentence generation matches the first reference translation."""
        batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
            ["UN Chief Says There Is No Military Solution in Syria"], return_tensors="pt"
        ).to(torch_device)
        translated_tokens = self.model.generate(**batch)
        decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
        self.assertEqual(self.tgt_text[0], decoded[0])

    @slow
    def test_enro_generate_batch(self):
        """Batched generation matches both reference translations."""
        batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(self.src_text, return_tensors="pt").to(
            torch_device
        )
        translated_tokens = self.model.generate(**batch)
        decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
        # unittest assertion (not a bare `assert`) for a readable diff on
        # failure, consistent with the rest of the file.
        self.assertListEqual(self.tgt_text, decoded)

    def test_mbart_enro_config(self):
        """The published en-ro config carries the flags MBart relies on."""
        mbart_models = ["facebook/mbart-large-en-ro"]
        expected = {"scale_embedding": True, "output_past": True}
        for name in mbart_models:
            config = MBartConfig.from_pretrained(name)
            self.assertTrue(config.is_valid_mbart())
            for k, v in expected.items():
                try:
                    self.assertEqual(v, getattr(config, k))
                except AssertionError as e:
                    # Attach model name and key so a failure is attributable.
                    e.args += (name, k)
                    raise

    def test_mbart_fast_forward(self):
        """A tiny randomly-initialized model's forward pass yields logits of the expected shape."""
        config = MBartConfig(
            vocab_size=99,
            d_model=24,
            encoder_layers=2,
            decoder_layers=2,
            encoder_attention_heads=2,
            decoder_attention_heads=2,
            encoder_ffn_dim=32,
            decoder_ffn_dim=32,
            max_position_embeddings=48,
            add_final_layer_norm=True,
        )
        lm_model = MBartForConditionalGeneration(config).to(torch_device)
        context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
        summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
        result = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
        # Logits are (batch, seq_len, vocab_size) for the decoder inputs.
        expected_shape = (*summary.shape, config.vocab_size)
        self.assertEqual(result.logits.shape, expected_shape)
157
158


159
@require_torch
@require_sentencepiece
@require_tokenizers
class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
    """Slow integration tests for the pretrained (not finetuned) facebook/mbart-large-cc25 checkpoint."""

    checkpoint_name = "facebook/mbart-large-cc25"
    src_text = [
        " UN Chief Says There Is No Military Solution in Syria",
        " I ate lunch twice yesterday",
    ]
    # NOTE(review): the Romanian reference was mojibake in the checked-in file
    # (UTF-8 mis-decoded); restored here.
    tgt_text = ["Şeful ONU declară că nu există o soluţie militară în Siria", "to be padded"]

    @unittest.skip("This test is broken, still generates english")
    def test_cc25_generate(self):
        """Generation with a forced ro_RO start token should translate to Romanian (currently broken)."""
        inputs = self.tokenizer.prepare_seq2seq_batch([self.src_text[0]], return_tensors="pt").to(torch_device)
        translated_tokens = self.model.generate(
            # `inputs` was already moved to torch_device above; no extra .to() needed.
            input_ids=inputs["input_ids"],
            decoder_start_token_id=self.tokenizer.lang_code_to_id["ro_RO"],
        )
        decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
        self.assertEqual(self.tgt_text[0], decoded[0])

    @slow
    def test_fill_mask(self):
        """The denoising pretraining objective fills <mask> with a plausible word."""
        inputs = self.tokenizer.prepare_seq2seq_batch(["One of the best <mask> I ever read!"], return_tensors="pt").to(
            torch_device
        )
        outputs = self.model.generate(
            inputs["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"], num_beams=1
        )
        prediction: str = self.tokenizer.batch_decode(
            outputs, clean_up_tokenization_spaces=True, skip_special_tokens=True
        )[0]
        self.assertEqual(prediction, "of the best books I ever read!")