# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import (
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
    TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
    Pipeline,
    ZeroShotClassificationPipeline,
    pipeline,
)
from transformers.testing_utils import nested_simplify, require_tf, require_torch, slow

from .test_pipelines_common import ANY


class ZeroShotClassificationPipelineTests(unittest.TestCase):
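    # Sequence-classification model mappings used to pick the architectures this pipeline is tested against.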
    model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
    tf_model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING

    def get_test_pipeline(self, model, tokenizer, processor):
        classifier = ZeroShotClassificationPipeline(
            model=model, tokenizer=tokenizer, candidate_labels=["politics", "health"]
        )
        return classifier, ["Who are you voting for in 2020?", "My stomach hurts."]

    def run_pipeline_test(self, classifier, _):
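        # Exercise the supported calling conventions: a single label string, labels passed
        # positionally, comma-separated label strings, explicit lists, custom hypothesis
        # templates, batched inputs, and the error cases for invalid arguments.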
        outputs = classifier("Who are you voting for in 2020?", candidate_labels="politics")
        self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})

        # Candidate labels passed positionally rather than as a keyword argument
        outputs = classifier("Who are you voting for in 2020?", ["politics"])
        self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})

        outputs = classifier("Who are you voting for in 2020?", candidate_labels=["politics"])
        self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})

        outputs = classifier("Who are you voting for in 2020?", candidate_labels="politics, public health")
        self.assertEqual(
            outputs, {"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
        )
        self.assertAlmostEqual(sum(nested_simplify(outputs["scores"])), 1.0)

        outputs = classifier("Who are you voting for in 2020?", candidate_labels=["politics", "public health"])
        self.assertEqual(
            outputs, {"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
        )
        self.assertAlmostEqual(sum(nested_simplify(outputs["scores"])), 1.0)

        outputs = classifier(
            "Who are you voting for in 2020?", candidate_labels="politics", hypothesis_template="This text is about {}"
        )
        self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})

        # https://github.com/huggingface/transformers/issues/13846
        outputs = classifier(["I am happy"], ["positive", "negative"])
        self.assertEqual(
            outputs,
            [
                {"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
                for _ in range(1)
            ],
        )
        outputs = classifier(["I am happy", "I am sad"], ["positive", "negative"])
        self.assertEqual(
            outputs,
            [
                {"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
                for _ in range(2)
            ],
        )

        # Invalid inputs should raise clear errors instead of silently returning output.
        with self.assertRaises(ValueError):
            classifier("", candidate_labels="politics")

        with self.assertRaises(TypeError):
            classifier(None, candidate_labels="politics")

        with self.assertRaises(ValueError):
            classifier("Who are you voting for in 2020?", candidate_labels="")

        with self.assertRaises(TypeError):
            classifier("Who are you voting for in 2020?", candidate_labels=None)

        with self.assertRaises(ValueError):
            classifier(
                "Who are you voting for in 2020?",
                candidate_labels="politics",
                hypothesis_template="Not formatting template",
            )

        with self.assertRaises(AttributeError):
            classifier(
                "Who are you voting for in 2020?",
                candidate_labels="politics",
                hypothesis_template=None,
            )

        self.run_entailment_id(classifier)

    def run_entailment_id(self, zero_shot_classifier: Pipeline):
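        # Check that the pipeline infers which label index means "entailment" from the
        # model config's label2id mapping, and falls back to -1 when no label matches.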
        config = zero_shot_classifier.model.config
        original_label2id = config.label2id
        original_entailment = zero_shot_classifier.entailment_id

        config.label2id = {"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2}
        self.assertEqual(zero_shot_classifier.entailment_id, -1)

        config.label2id = {"entailment": 0, "neutral": 1, "contradiction": 2}
        self.assertEqual(zero_shot_classifier.entailment_id, 0)

        config.label2id = {"ENTAIL": 0, "NON-ENTAIL": 1}
        self.assertEqual(zero_shot_classifier.entailment_id, 0)

        config.label2id = {"ENTAIL": 2, "NEUTRAL": 1, "CONTR": 0}
        self.assertEqual(zero_shot_classifier.entailment_id, 2)

        # Restore the original mapping and check the entailment id returns to its original value.
        zero_shot_classifier.model.config.label2id = original_label2id
        self.assertEqual(original_entailment, zero_shot_classifier.entailment_id)

    @require_torch
    def test_truncation(self):
        zero_shot_classifier = pipeline(
            "zero-shot-classification",
            model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
            framework="pt",
        )
        # Regression test for a bug introduced in 4.10: very long inputs must be truncated
        # rather than raising an error.
        # https://github.com/huggingface/transformers/issues/13381#issuecomment-912343499
        zero_shot_classifier(
            "Who are you voting for in 2020?" * 100, candidate_labels=["politics", "public health", "science"]
        )

    @require_torch
    def test_small_model_pt(self):
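        # The tiny model yields near-uniform scores, so this mainly checks the output structure.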
        zero_shot_classifier = pipeline(
            "zero-shot-classification",
            model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
            framework="pt",
        )
        outputs = zero_shot_classifier(
            "Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
        )

        self.assertEqual(
            nested_simplify(outputs),
            {
                "sequence": "Who are you voting for in 2020?",
                "labels": ["science", "public health", "politics"],
                "scores": [0.333, 0.333, 0.333],
            },
        )

    @require_tf
    def test_small_model_tf(self):
        zero_shot_classifier = pipeline(
            "zero-shot-classification",
            model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
            framework="tf",
        )
        outputs = zero_shot_classifier(
            "Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
        )

        self.assertEqual(
            nested_simplify(outputs),
            {
                "sequence": "Who are you voting for in 2020?",
                "labels": ["science", "public health", "politics"],
                "scores": [0.333, 0.333, 0.333],
            },
        )

    @slow
    @require_torch
    def test_large_model_pt(self):
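        # Full-size MNLI model, so the expected scores are meaningful reference values.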
        zero_shot_classifier = pipeline("zero-shot-classification", model="roberta-large-mnli", framework="pt")
        outputs = zero_shot_classifier(
            "Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
        )

        self.assertEqual(
            nested_simplify(outputs),
            {
                "sequence": "Who are you voting for in 2020?",
                "labels": ["politics", "public health", "science"],
                "scores": [0.976, 0.015, 0.009],
            },
        )
        outputs = zero_shot_classifier(
            "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks"
            " in an encoder-decoder configuration. The best performing models also connect the encoder and decoder"
            " through an attention mechanism. We propose a new simple network architecture, the Transformer, based"
            " solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two"
            " machine translation tasks show these models to be superior in quality while being more parallelizable"
            " and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014"
            " English-to-German translation task, improving over the existing best results, including ensembles by"
            " over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new"
            " single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small"
            " fraction of the training costs of the best models from the literature. We show that the Transformer"
            " generalizes well to other tasks by applying it successfully to English constituency parsing both with"
            " large and limited training data.",
            candidate_labels=["machine learning", "statistics", "translation", "vision"],
            multi_label=True,
        )
        self.assertEqual(
            nested_simplify(outputs),
            {
                "sequence": (
                    "The dominant sequence transduction models are based on complex recurrent or convolutional neural"
                    " networks in an encoder-decoder configuration. The best performing models also connect the"
                    " encoder and decoder through an attention mechanism. We propose a new simple network"
                    " architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence"
                    " and convolutions entirely. Experiments on two machine translation tasks show these models to be"
                    " superior in quality while being more parallelizable and requiring significantly less time to"
                    " train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task,"
                    " improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014"
                    " English-to-French translation task, our model establishes a new single-model state-of-the-art"
                    " BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training"
                    " costs of the best models from the literature. We show that the Transformer generalizes well to"
                    " other tasks by applying it successfully to English constituency parsing both with large and"
                    " limited training data."
                ),
                "labels": ["translation", "machine learning", "vision", "statistics"],
                "scores": [0.817, 0.713, 0.018, 0.018],
            },
        )

    @slow
    @require_tf
    def test_large_model_tf(self):
        zero_shot_classifier = pipeline("zero-shot-classification", model="roberta-large-mnli", framework="tf")
        outputs = zero_shot_classifier(
            "Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
        )

        self.assertEqual(
            nested_simplify(outputs),
            {
                "sequence": "Who are you voting for in 2020?",
                "labels": ["politics", "public health", "science"],
                "scores": [0.976, 0.015, 0.009],
            },
        )
        outputs = zero_shot_classifier(
            "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks"
            " in an encoder-decoder configuration. The best performing models also connect the encoder and decoder"
            " through an attention mechanism. We propose a new simple network architecture, the Transformer, based"
            " solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two"
            " machine translation tasks show these models to be superior in quality while being more parallelizable"
            " and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014"
            " English-to-German translation task, improving over the existing best results, including ensembles by"
            " over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new"
            " single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small"
            " fraction of the training costs of the best models from the literature. We show that the Transformer"
            " generalizes well to other tasks by applying it successfully to English constituency parsing both with"
            " large and limited training data.",
            candidate_labels=["machine learning", "statistics", "translation", "vision"],
            multi_label=True,
        )
        self.assertEqual(
            nested_simplify(outputs),
            {
                "sequence": (
                    "The dominant sequence transduction models are based on complex recurrent or convolutional neural"
                    " networks in an encoder-decoder configuration. The best performing models also connect the"
                    " encoder and decoder through an attention mechanism. We propose a new simple network"
                    " architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence"
                    " and convolutions entirely. Experiments on two machine translation tasks show these models to be"
                    " superior in quality while being more parallelizable and requiring significantly less time to"
                    " train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task,"
                    " improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014"
                    " English-to-French translation task, our model establishes a new single-model state-of-the-art"
                    " BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training"
                    " costs of the best models from the literature. We show that the Transformer generalizes well to"
                    " other tasks by applying it successfully to English constituency parsing both with large and"
                    " limited training data."
                ),
                "labels": ["translation", "machine learning", "vision", "statistics"],
                "scores": [0.817, 0.713, 0.018, 0.018],
            },
        )