# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import pipeline
from transformers.testing_utils import require_tf, require_torch, slow

from .test_pipelines_common import MonoInputPipelineCommonMixin


# Expected top-2 predictions for each entry of ``FillMaskPipelineTests.valid_inputs`` when running
# the large model (``distilroberta-base``) with ``top_k=2``. One inner list per input, ordered by
# descending score; scores are compared by the tests to 3 decimal places.
EXPECTED_FILL_MASK_RESULT = [
    [
        {"sequence": "My name is John", "score": 0.00782308354973793, "token": 610, "token_str": " John"},
        {"sequence": "My name is Chris", "score": 0.007475061342120171, "token": 1573, "token_str": " Chris"},
    ],
    [
        {
            "sequence": "The largest city in France is Paris",
            "score": 0.2510891854763031,
            "token": 2201,
            "token_str": " Paris",
        },
        {
            "sequence": "The largest city in France is Lyon",
            "score": 0.21418564021587372,
            "token": 12790,
            "token_str": " Lyon",
        },
    ],
]

# Expected predictions when only the first input is run.
# NOTE(review): despite the name, the tests compare this against untargeted predictions,
# which is why it simply reuses the first entry of EXPECTED_FILL_MASK_RESULT.
EXPECTED_FILL_MASK_TARGET_RESULT = [EXPECTED_FILL_MASK_RESULT[0]]


class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
    """Tests for the ``fill-mask`` pipeline on the PyTorch and TensorFlow backends.

    The class attributes are presumably consumed by ``MonoInputPipelineCommonMixin`` (TODO confirm
    against the mixin). The ``test_*`` methods below add fill-mask specific checks: ``targets``
    handling on the small models and exact predictions on the large models.
    """

    pipeline_task = "fill-mask"
    pipeline_loading_kwargs = {"top_k": 2}
    small_models = ["sshleifer/tiny-distilroberta-base"]  # Models tested without the @slow decorator
    large_models = ["distilroberta-base"]  # Models tested with the @slow decorator
    mandatory_keys = {"sequence", "score", "token"}
    valid_inputs = [
        "My name is <mask>",
        "The largest city in France is <mask>",
    ]
    # Keep the trailing commas: adjacent string literals would otherwise be concatenated
    # into a single (different) invalid input.
    invalid_inputs = [
        "This is <mask> <mask>",  # More than 1 mask_token in the input is not supported
        "This is",  # No mask_token is not supported
    ]
    expected_check_keys = ["sequence"]

    @require_torch
    def test_torch_fill_mask(self):
        """Smoke-test a single masked string with and without extra call-time kwargs."""
        valid_inputs = "My name is <mask>"
        unmasker = pipeline(task="fill-mask", model=self.small_models[0])
        outputs = unmasker(valid_inputs)
        self.assertIsInstance(outputs, list)

        # This passes
        outputs = unmasker(valid_inputs, targets=[" Patrick", " Clara"])
        self.assertIsInstance(outputs, list)

        # This used to fail with `cannot mix args and kwargs`
        outputs = unmasker(valid_inputs, something=False)
        self.assertIsInstance(outputs, list)

    def _assert_fill_mask_with_targets(self, framework):
        """Check ``targets=`` handling for every small model under ``framework`` ("pt" or "tf").

        Each valid target list must yield exactly one prediction per target, and each
        invalid target value (empty list, blank token, empty string) must raise ``ValueError``.
        """
        valid_inputs = ["My name is <mask>"]
        valid_targets = [[" Teven", " Patrick", " Clara"], [" Sam"]]
        invalid_targets = [[], [""], ""]
        for model_name in self.small_models:
            unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework=framework)
            for targets in valid_targets:
                outputs = unmasker(valid_inputs, targets=targets)
                self.assertIsInstance(outputs, list)
                self.assertEqual(len(outputs), len(targets))
            for targets in invalid_targets:
                self.assertRaises(ValueError, unmasker, valid_inputs, targets=targets)

    @require_torch
    def test_torch_fill_mask_with_targets(self):
        """``targets=`` handling on the PyTorch backend."""
        self._assert_fill_mask_with_targets("pt")

    @require_tf
    def test_tf_fill_mask_with_targets(self):
        """``targets=`` handling on the TensorFlow backend."""
        self._assert_fill_mask_with_targets("tf")

    def _assert_unmasker_outputs(self, unmasker, inputs, targets, expected_results):
        """Run ``unmasker`` on ``inputs`` and compare its predictions with ``expected_results``.

        Args:
            unmasker: A ``fill-mask`` pipeline instance.
            inputs: Non-empty list of strings, each with a single ``<mask>``.
            targets: Candidate tokens passed as ``targets=`` for the call on ``inputs[0]``; that
                output is only checked for structure and mandatory keys.
            expected_results: One list of expected prediction dicts per entry of ``inputs``,
                compared field by field against the *untargeted* predictions.
        """
        mono_result = unmasker(inputs[0], targets=targets)
        self.assertIsInstance(mono_result, list)
        self.assertIsInstance(mono_result[0], dict)
        for mandatory_key in self.mandatory_keys:
            self.assertIn(mandatory_key, mono_result[0])

        multi_result = [unmasker(text) for text in inputs]
        self.assertIsInstance(multi_result, list)
        self.assertIsInstance(multi_result[0], (dict, list))

        # zip() silently stops at the shorter sequence, so pin the lengths before comparing.
        self.assertEqual(len(multi_result), len(expected_results))
        for result, expected in zip(multi_result, expected_results):
            self.assertEqual(len(result), len(expected))
            for r, e in zip(result, expected):
                self.assertEqual(r["sequence"], e["sequence"])
                self.assertEqual(r["token_str"], e["token_str"])
                self.assertEqual(r["token"], e["token"])
                self.assertAlmostEqual(r["score"], e["score"], places=3)

        # For nested per-input results, this key check covers the first input's predictions.
        if isinstance(multi_result[0], list):
            multi_result = multi_result[0]
        for result in multi_result:
            for key in self.mandatory_keys:
                self.assertIn(key, result)

        # A batch containing ``None`` must be rejected.
        self.assertRaises(Exception, unmasker, [None])

    def _assert_fill_mask_results(self, framework):
        """Check exact top-2 predictions of every large model under ``framework`` ("pt" or "tf")."""
        valid_inputs = [
            "My name is <mask>",
            "The largest city in France is <mask>",
        ]
        valid_targets = [" Patrick", " Clara"]
        for model_name in self.large_models:
            unmasker = pipeline(
                task="fill-mask",
                model=model_name,
                tokenizer=model_name,
                framework=framework,
                top_k=2,
            )
            self._assert_unmasker_outputs(unmasker, valid_inputs, valid_targets, EXPECTED_FILL_MASK_RESULT)
            # Repeat with only the first input. Slice rather than rebinding ``valid_inputs`` so
            # every model in ``large_models`` is checked on the full input list above.
            # NOTE(review): despite its name, EXPECTED_FILL_MASK_TARGET_RESULT is compared with
            # untargeted predictions.
            self._assert_unmasker_outputs(
                unmasker, valid_inputs[:1], valid_targets, EXPECTED_FILL_MASK_TARGET_RESULT
            )

    @require_torch
    @slow
    def test_torch_fill_mask_results(self):
        """Exact predictions of the large models on the PyTorch backend."""
        self._assert_fill_mask_results("pt")

    @require_tf
    @slow
    def test_tf_fill_mask_results(self):
        """Exact predictions of the large models on the TensorFlow backend."""
        self._assert_fill_mask_results("tf")