# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import pipeline
from transformers.testing_utils import require_tf, require_torch, slow

from .test_pipelines_common import MonoInputPipelineCommonMixin


# Reference top-2 predictions (no `targets`) for each sentence in
# `FillMaskPipelineTests.valid_inputs`, in the same order. The slow tests compare
# `sequence`, `token` and `token_str` exactly and `score` to 3 decimal places.
EXPECTED_FILL_MASK_RESULT = [
    [
        {"sequence": "My name is John", "score": 0.00782308354973793, "token": 610, "token_str": " John"},
        {"sequence": "My name is Chris", "score": 0.007475061342120171, "token": 1573, "token_str": " Chris"},
    ],
    [
        {
            "sequence": "The largest city in France is Paris",
            "score": 0.2510891854763031,
            "token": 2201,
            "token_str": " Paris",
        },
        {
            "sequence": "The largest city in France is Lyon",
            "score": 0.21418564021587372,
            "token": 12790,
            "token_str": " Lyon",
        },
    ],
]

# Reference predictions for the second pass of the slow tests, which only uses the
# first input sentence ("My name is <mask>").
EXPECTED_FILL_MASK_TARGET_RESULT = [EXPECTED_FILL_MASK_RESULT[0]]


class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
    """Tests for the ``"fill-mask"`` pipeline on the PyTorch and TensorFlow backends."""

    pipeline_task = "fill-mask"
    pipeline_loading_kwargs = {"top_k": 2}
    small_models = ["sshleifer/tiny-distilroberta-base"]  # Models tested without the @slow decorator
    large_models = ["distilroberta-base"]  # Models tested with the @slow decorator
    mandatory_keys = {"sequence", "score", "token"}
    valid_inputs = [
        "My name is <mask>",
        "The largest city in France is <mask>",
    ]
    # Each entry needs its trailing comma: without it Python concatenates adjacent
    # string literals and the two cases collapse into a single input.
    invalid_inputs = [
        "This is <mask> <mask>",  # More than 1 mask_token in the input is not supported
        "This is",  # No mask_token is not supported
    ]
    expected_check_keys = ["sequence"]

    @require_torch
    def test_torch_fill_mask(self):
        valid_inputs = "My name is <mask>"
        unmasker = pipeline(task="fill-mask", model=self.small_models[0])
        outputs = unmasker(valid_inputs)
        self.assertIsInstance(outputs, list)

        # Passing `targets` as a keyword argument is supported.
        outputs = unmasker(valid_inputs, targets=[" Patrick", " Clara"])
        self.assertIsInstance(outputs, list)

        # This used to fail with `cannot mix args and kwargs`
        outputs = unmasker(valid_inputs, something=False)
        self.assertIsInstance(outputs, list)

    def _run_fill_mask_with_targets(self, framework):
        """Check `targets` handling for every small model on ``framework`` ("pt" or "tf").

        Valid target lists must yield one prediction per target; empty or blank
        targets must raise ``ValueError``.
        """
        valid_inputs = ["My name is <mask>"]
        # "Ġ" (U+0120) is the byte-level BPE marker for a leading space, so "ĠPatrick"
        # is a vocabulary token as-is. ' Sam' will yield a warning but work.
        valid_targets = [[" Teven", "ĠPatrick", "ĠClara"], ["ĠSam"], [" Sam"]]
        invalid_targets = [[], [""], ""]
        for model_name in self.small_models:
            unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework=framework)
            for targets in valid_targets:
                outputs = unmasker(valid_inputs, targets=targets)
                self.assertIsInstance(outputs, list)
                self.assertEqual(len(outputs), len(targets))
            for targets in invalid_targets:
                self.assertRaises(ValueError, unmasker, valid_inputs, targets=targets)

    @require_torch
    def test_torch_fill_mask_with_targets(self):
        self._run_fill_mask_with_targets(framework="pt")

    @require_torch
    def test_torch_fill_mask_with_targets_and_topk(self):
        model_name = self.small_models[0]
        unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
        targets = [" Teven", "ĠPatrick", "ĠClara"]
        top_k = 2
        outputs = unmasker("My name is <mask>", targets=targets, top_k=top_k)

        # top_k is smaller than the number of targets, so it caps the output length.
        self.assertEqual(len(outputs), top_k)

    @require_torch
    def test_torch_fill_mask_with_duplicate_targets_and_topk(self):
        model_name = self.small_models[0]
        unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
        # String duplicates + id duplicates
        targets = [" Teven", "ĠPatrick", "ĠClara", "ĠClara", " Clara"]
        top_k = 10
        outputs = unmasker("My name is <mask>", targets=targets, top_k=top_k)

        # The target list contains duplicates, so we can't output more
        # than them
        self.assertEqual(len(outputs), 3)

    @require_tf
    def test_tf_fill_mask_with_targets(self):
        self._run_fill_mask_with_targets(framework="tf")

    def _check_fill_mask_results(self, unmasker, valid_inputs, valid_targets, expected_results):
        """Run the shared result assertions against one loaded ``unmasker`` pipeline.

        Args:
            unmasker: a fill-mask pipeline created with ``top_k=2``.
            valid_inputs: masked sentences; the first one is also used with ``targets``.
            valid_targets: passed as ``targets=`` for the single-input structural check.
            expected_results: per input in ``valid_inputs``, the expected predictions
                of a call made without ``targets``.
        """
        # Single input restricted to `valid_targets`: only the output structure is checked.
        mono_result = unmasker(valid_inputs[0], targets=valid_targets)
        self.assertIsInstance(mono_result, list)
        self.assertIsInstance(mono_result[0], dict)
        for mandatory_key in self.mandatory_keys:
            self.assertIn(mandatory_key, mono_result[0])

        # One call per input, without targets, compared against the reference predictions.
        multi_result = [unmasker(valid_input) for valid_input in valid_inputs]
        self.assertIsInstance(multi_result, list)
        self.assertIsInstance(multi_result[0], (dict, list))

        # Length checks stop `zip` from silently skipping missing predictions.
        self.assertEqual(len(multi_result), len(expected_results))
        for result, expected in zip(multi_result, expected_results):
            self.assertEqual(len(result), len(expected))
            for r, e in zip(result, expected):
                self.assertEqual(r["sequence"], e["sequence"])
                self.assertEqual(r["token_str"], e["token_str"])
                self.assertEqual(r["token"], e["token"])
                self.assertAlmostEqual(r["score"], e["score"], places=3)

        if isinstance(multi_result[0], list):
            multi_result = multi_result[0]

        for result in multi_result:
            for key in self.mandatory_keys:
                self.assertIn(key, result)

        self.assertRaises(Exception, unmasker, [None])

    def _run_fill_mask_results(self, framework):
        """Check reference predictions of every large model on ``framework`` ("pt" or "tf")."""
        valid_targets = ["ĠPatrick", "ĠClara"]
        for model_name in self.large_models:
            unmasker = pipeline(
                task="fill-mask",
                model=model_name,
                tokenizer=model_name,
                framework=framework,
                top_k=2,
            )
            # All inputs, then only the first one. Slicing a fresh copy here (rather than
            # rebinding a shared variable) keeps every model tested on the full input list.
            self._check_fill_mask_results(unmasker, self.valid_inputs, valid_targets, EXPECTED_FILL_MASK_RESULT)
            self._check_fill_mask_results(
                unmasker, self.valid_inputs[:1], valid_targets, EXPECTED_FILL_MASK_TARGET_RESULT
            )

    @require_torch
    @slow
    def test_torch_fill_mask_results(self):
        self._run_fill_mask_results(framework="pt")

    @require_tf
    @slow
    def test_tf_fill_mask_results(self):
        self._run_fill_mask_results(framework="tf")