Unverified Commit b89a964d authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Moving `zero-shot-classification` pipeline to new testing. (#13299)

* Moving `zero-shot-classification` pipeline to new testing.

* Cleaning up old mixins.

* Fixing tests
`sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english` is
corrupted in PT.

* Adding warning.
parent cc27ac1a
...@@ -2,12 +2,15 @@ from typing import List, Union ...@@ -2,12 +2,15 @@ from typing import List, Union
import numpy as np import numpy as np
from ..file_utils import add_end_docstrings from ..file_utils import add_end_docstrings, is_torch_available
from ..tokenization_utils import TruncationStrategy from ..tokenization_utils import TruncationStrategy
from ..utils import logging from ..utils import logging
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline
if is_torch_available():
import torch
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -85,23 +88,84 @@ class ZeroShotClassificationPipeline(Pipeline): ...@@ -85,23 +88,84 @@ class ZeroShotClassificationPipeline(Pipeline):
hypothesis_template, hypothesis_template,
padding=True, padding=True,
add_special_tokens=True, add_special_tokens=True,
truncation=TruncationStrategy.ONLY_FIRST, truncation=TruncationStrategy.DO_NOT_TRUNCATE,
**kwargs **kwargs
): ):
""" """
Parse arguments and tokenize only_first so that hypothesis (label) is not truncated Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
""" """
sequence_pairs = self._args_parser(sequences, candidate_labels, hypothesis_template) sequence_pairs = self._args_parser(sequences, candidate_labels, hypothesis_template)
inputs = self.tokenizer( return_tensors = self.framework
sequence_pairs, if getattr(self.tokenizer, "pad_token", None) is None:
add_special_tokens=add_special_tokens, # XXX some tokenizers do not have a padding token, we use simple lists
return_tensors=self.framework, # and no padding then
padding=padding, logger.warning("The tokenizer {self.tokenizer} does not have a pad token, we're not running it as a batch")
truncation=truncation, padding = False
) inputs = []
for sequence_pair in sequence_pairs:
model_input = self.tokenizer(
text=sequence_pair[0],
text_pair=sequence_pair[1],
add_special_tokens=add_special_tokens,
return_tensors=return_tensors,
padding=padding,
truncation=truncation,
)
inputs.append(model_input)
else:
inputs = self.tokenizer(
sequence_pairs,
add_special_tokens=add_special_tokens,
return_tensors=return_tensors,
padding=padding,
truncation=truncation,
)
return inputs return inputs
def _forward(self, inputs, return_tensors=False):
"""
Internal framework specific forward dispatching
Args:
inputs: dict holding all the keyword arguments for required by the model forward method.
return_tensors: Whether to return native framework (pt/tf) tensors rather than numpy array
Returns:
Numpy array
"""
# Encode for forward
with self.device_placement():
if self.framework == "tf":
if isinstance(inputs, list):
predictions = []
for input_ in inputs:
prediction = self.model(input_.data, training=False)[0]
predictions.append(prediction)
else:
predictions = self.model(inputs.data, training=False)[0]
else:
with torch.no_grad():
if isinstance(inputs, list):
predictions = []
for input_ in inputs:
model_input = self.ensure_tensor_on_device(**input_)
prediction = self.model(**model_input)[0].cpu()
predictions.append(prediction)
else:
inputs = self.ensure_tensor_on_device(**inputs)
predictions = self.model(**inputs)[0].cpu()
if return_tensors:
return predictions
else:
if isinstance(predictions, list):
predictions = np.array([p.numpy() for p in predictions])
else:
predictions = predictions.numpy()
return predictions
def __call__( def __call__(
self, self,
sequences: Union[str, List[str]], sequences: Union[str, List[str]],
...@@ -151,6 +215,12 @@ class ZeroShotClassificationPipeline(Pipeline): ...@@ -151,6 +215,12 @@ class ZeroShotClassificationPipeline(Pipeline):
sequences = [sequences] sequences = [sequences]
outputs = super().__call__(sequences, candidate_labels, hypothesis_template) outputs = super().__call__(sequences, candidate_labels, hypothesis_template)
if isinstance(outputs, list):
# XXX: Some tokenizers cannot handle batching because they don't
# have pad_token, so outputs will be a list, however, because outputs
# is only n logits and sequence_length is not present anymore, we
# can recreate a tensor out of outputs.
outputs = np.array(outputs)
num_sequences = len(sequences) num_sequences = len(sequences)
candidate_labels = self._args_parser._parse_labels(candidate_labels) candidate_labels = self._args_parser._parse_labels(candidate_labels)
reshaped_outputs = outputs.reshape((num_sequences, len(candidate_labels), -1)) reshaped_outputs = outputs.reshape((num_sequences, len(candidate_labels), -1))
......
...@@ -17,21 +17,9 @@ import logging ...@@ -17,21 +17,9 @@ import logging
import string import string
from abc import abstractmethod from abc import abstractmethod
from functools import lru_cache from functools import lru_cache
from typing import List, Optional from unittest import skipIf
from unittest import mock, skipIf
from transformers import FEATURE_EXTRACTOR_MAPPING, TOKENIZER_MAPPING, AutoFeatureExtractor, AutoTokenizer
from transformers import (
FEATURE_EXTRACTOR_MAPPING,
TOKENIZER_MAPPING,
AutoFeatureExtractor,
AutoTokenizer,
is_tf_available,
is_torch_available,
pipeline,
)
from transformers.file_utils import to_py_obj
from transformers.pipelines import Pipeline
from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -189,228 +177,3 @@ class PipelineTestCaseMeta(type): ...@@ -189,228 +177,3 @@ class PipelineTestCaseMeta(type):
dct["test_small_model_tf"] = dct.get("test_small_model_tf", inner) dct["test_small_model_tf"] = dct.get("test_small_model_tf", inner)
return type.__new__(mcs, name, bases, dct) return type.__new__(mcs, name, bases, dct)
VALID_INPUTS = ["A simple string", ["list of strings"]]
@is_pipeline_test
class CustomInputPipelineCommonMixin:
pipeline_task = None
pipeline_loading_kwargs = {} # Additional kwargs to load the pipeline with
pipeline_running_kwargs = {} # Additional kwargs to run the pipeline with
small_models = [] # Models tested without the @slow decorator
large_models = [] # Models tested with the @slow decorator
valid_inputs = VALID_INPUTS # Some inputs which are valid to compare fast and slow tokenizers
def setUp(self) -> None:
if not is_tf_available() and not is_torch_available():
return # Currently no JAX pipelines
# Download needed checkpoints
models = self.small_models
if _run_slow_tests:
models = models + self.large_models
for model_name in models:
if is_torch_available():
pipeline(
self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="pt",
**self.pipeline_loading_kwargs,
)
if is_tf_available():
pipeline(
self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="tf",
**self.pipeline_loading_kwargs,
)
@require_torch
@slow
def test_pt_defaults(self):
pipeline(self.pipeline_task, framework="pt", **self.pipeline_loading_kwargs)
@require_tf
@slow
def test_tf_defaults(self):
pipeline(self.pipeline_task, framework="tf", **self.pipeline_loading_kwargs)
@require_torch
def test_torch_small(self):
for model_name in self.small_models:
pipe_small = pipeline(
task=self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="pt",
**self.pipeline_loading_kwargs,
)
self._test_pipeline(pipe_small)
@require_tf
def test_tf_small(self):
for model_name in self.small_models:
pipe_small = pipeline(
task=self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="tf",
**self.pipeline_loading_kwargs,
)
self._test_pipeline(pipe_small)
@require_torch
@slow
def test_torch_large(self):
for model_name in self.large_models:
pipe_large = pipeline(
task=self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="pt",
**self.pipeline_loading_kwargs,
)
self._test_pipeline(pipe_large)
@require_tf
@slow
def test_tf_large(self):
for model_name in self.large_models:
pipe_large = pipeline(
task=self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="tf",
**self.pipeline_loading_kwargs,
)
self._test_pipeline(pipe_large)
def _test_pipeline(self, pipe: Pipeline):
raise NotImplementedError
@require_torch
def test_compare_slow_fast_torch(self):
for model_name in self.small_models:
pipe_slow = pipeline(
task=self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="pt",
use_fast=False,
**self.pipeline_loading_kwargs,
)
pipe_fast = pipeline(
task=self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="pt",
use_fast=True,
**self.pipeline_loading_kwargs,
)
self._compare_slow_fast_pipelines(pipe_slow, pipe_fast, method="forward")
@require_tf
def test_compare_slow_fast_tf(self):
for model_name in self.small_models:
pipe_slow = pipeline(
task=self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="tf",
use_fast=False,
**self.pipeline_loading_kwargs,
)
pipe_fast = pipeline(
task=self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="tf",
use_fast=True,
**self.pipeline_loading_kwargs,
)
self._compare_slow_fast_pipelines(pipe_slow, pipe_fast, method="call")
def _compare_slow_fast_pipelines(self, pipe_slow: Pipeline, pipe_fast: Pipeline, method: str):
"""We check that the inputs to the models forward passes are identical for
slow and fast tokenizers.
"""
with mock.patch.object(
pipe_slow.model, method, wraps=getattr(pipe_slow.model, method)
) as mock_slow, mock.patch.object(
pipe_fast.model, method, wraps=getattr(pipe_fast.model, method)
) as mock_fast:
for inputs in self.valid_inputs:
if isinstance(inputs, dict):
inputs.update(self.pipeline_running_kwargs)
_ = pipe_slow(**inputs)
_ = pipe_fast(**inputs)
else:
_ = pipe_slow(inputs, **self.pipeline_running_kwargs)
_ = pipe_fast(inputs, **self.pipeline_running_kwargs)
mock_slow.assert_called()
mock_fast.assert_called()
self.assertEqual(len(mock_slow.call_args_list), len(mock_fast.call_args_list))
for mock_slow_call_args, mock_fast_call_args in zip(
mock_slow.call_args_list, mock_slow.call_args_list
):
slow_call_args, slow_call_kwargs = mock_slow_call_args
fast_call_args, fast_call_kwargs = mock_fast_call_args
slow_call_args, slow_call_kwargs = to_py_obj(slow_call_args), to_py_obj(slow_call_kwargs)
fast_call_args, fast_call_kwargs = to_py_obj(fast_call_args), to_py_obj(fast_call_kwargs)
self.assertEqual(slow_call_args, fast_call_args)
self.assertDictEqual(slow_call_kwargs, fast_call_kwargs)
@is_pipeline_test
class MonoInputPipelineCommonMixin(CustomInputPipelineCommonMixin):
"""A version of the CustomInputPipelineCommonMixin
with a predefined `_test_pipeline` method.
"""
mandatory_keys = {} # Keys which should be in the output
invalid_inputs = [None] # inputs which are not allowed
expected_multi_result: Optional[List] = None
expected_check_keys: Optional[List[str]] = None
def _test_pipeline(self, pipe: Pipeline):
self.assertIsNotNone(pipe)
mono_result = pipe(self.valid_inputs[0], **self.pipeline_running_kwargs)
self.assertIsInstance(mono_result, list)
self.assertIsInstance(mono_result[0], (dict, list))
if isinstance(mono_result[0], list):
mono_result = mono_result[0]
for key in self.mandatory_keys:
self.assertIn(key, mono_result[0])
multi_result = [pipe(input, **self.pipeline_running_kwargs) for input in self.valid_inputs]
self.assertIsInstance(multi_result, list)
self.assertIsInstance(multi_result[0], (dict, list))
if self.expected_multi_result is not None:
for result, expect in zip(multi_result, self.expected_multi_result):
for key in self.expected_check_keys or []:
self.assertEqual(
set([o[key] for o in result]),
set([o[key] for o in expect]),
)
if isinstance(multi_result[0], list):
multi_result = multi_result[0]
for result in multi_result:
for key in self.mandatory_keys:
self.assertIn(key, result)
self.assertRaises(Exception, pipe, self.invalid_inputs)
...@@ -13,39 +13,82 @@ ...@@ -13,39 +13,82 @@
# limitations under the License. # limitations under the License.
import unittest import unittest
from copy import deepcopy
from transformers import (
from transformers.pipelines import Pipeline MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
from .test_pipelines_common import CustomInputPipelineCommonMixin Pipeline,
ZeroShotClassificationPipeline,
pipeline,
class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase): )
pipeline_task = "zero-shot-classification" from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
small_models = ["sgugger/tiny-distilbert-classification"] # Models tested without the @slow decorator
large_models = ["roberta-large-mnli"] # Models tested with the @slow decorator from .test_pipelines_common import ANY, PipelineTestCaseMeta
valid_inputs = [
{"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics"},
{"sequences": "Who are you voting for in 2020?", "candidate_labels": ["politics"]}, @is_pipeline_test
{"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics, public health"}, class ZeroShotClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
{"sequences": "Who are you voting for in 2020?", "candidate_labels": ["politics", "public health"]}, model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
{"sequences": ["Who are you voting for in 2020?"], "candidate_labels": "politics"}, tf_model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
{
"sequences": "Who are you voting for in 2020?", def run_pipeline_test(self, model, tokenizer, feature_extractor):
"candidate_labels": "politics", classifier = ZeroShotClassificationPipeline(model=model, tokenizer=tokenizer)
"hypothesis_template": "This text is about {}",
}, outputs = classifier("Who are you voting for in 2020?", candidate_labels="politics")
] self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
def _test_scores_sum_to_one(self, result): outputs = classifier("Who are you voting for in 2020?", candidate_labels=["politics"])
sum = 0.0 self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
for score in result["scores"]:
sum += score outputs = classifier("Who are you voting for in 2020?", candidate_labels="politics, public health")
self.assertAlmostEqual(sum, 1.0, places=5) self.assertEqual(
outputs, {"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
def _test_entailment_id(self, zero_shot_classifier: Pipeline): )
self.assertAlmostEqual(sum(nested_simplify(outputs["scores"])), 1.0)
outputs = classifier("Who are you voting for in 2020?", candidate_labels=["politics", "public health"])
self.assertEqual(
outputs, {"sequence": ANY(str), "labels": [ANY(str), ANY(str)], "scores": [ANY(float), ANY(float)]}
)
self.assertAlmostEqual(sum(nested_simplify(outputs["scores"])), 1.0)
outputs = classifier(
"Who are you voting for in 2020?", candidate_labels="politics", hypothesis_template="This text is about {}"
)
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})
with self.assertRaises(ValueError):
classifier("", candidate_labels="politics")
with self.assertRaises(TypeError):
classifier(None, candidate_labels="politics")
with self.assertRaises(ValueError):
classifier("Who are you voting for in 2020?", candidate_labels="")
with self.assertRaises(TypeError):
classifier("Who are you voting for in 2020?", candidate_labels=None)
with self.assertRaises(ValueError):
classifier(
"Who are you voting for in 2020?",
candidate_labels="politics",
hypothesis_template="Not formatting template",
)
with self.assertRaises(AttributeError):
classifier(
"Who are you voting for in 2020?",
candidate_labels="politics",
hypothesis_template=None,
)
self.run_entailment_id(classifier)
def run_entailment_id(self, zero_shot_classifier: Pipeline):
config = zero_shot_classifier.model.config config = zero_shot_classifier.model.config
original_config = deepcopy(config) original_label2id = config.label2id
original_entailment = zero_shot_classifier.entailment_id
config.label2id = {"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2} config.label2id = {"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2}
self.assertEqual(zero_shot_classifier.entailment_id, -1) self.assertEqual(zero_shot_classifier.entailment_id, -1)
...@@ -59,107 +102,105 @@ class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unitte ...@@ -59,107 +102,105 @@ class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unitte
config.label2id = {"ENTAIL": 2, "NEUTRAL": 1, "CONTR": 0} config.label2id = {"ENTAIL": 2, "NEUTRAL": 1, "CONTR": 0}
self.assertEqual(zero_shot_classifier.entailment_id, 2) self.assertEqual(zero_shot_classifier.entailment_id, 2)
zero_shot_classifier.model.config = original_config zero_shot_classifier.model.config.label2id = original_label2id
self.assertEqual(original_entailment, zero_shot_classifier.entailment_id)
def _test_pipeline(self, zero_shot_classifier: Pipeline):
output_keys = {"sequence", "labels", "scores"} @require_torch
valid_mono_inputs = [ def test_small_model_pt(self):
{"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics"}, zero_shot_classifier = pipeline(
{"sequences": "Who are you voting for in 2020?", "candidate_labels": ["politics"]}, "zero-shot-classification",
{"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics, public health"}, model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
{"sequences": "Who are you voting for in 2020?", "candidate_labels": ["politics", "public health"]}, framework="pt",
{"sequences": ["Who are you voting for in 2020?"], "candidate_labels": "politics"}, )
outputs = zero_shot_classifier(
"Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
)
self.assertEqual(
nested_simplify(outputs),
{
"sequence": "Who are you voting for in 2020?",
"labels": ["science", "public health", "politics"],
"scores": [0.333, 0.333, 0.333],
},
)
@require_tf
def test_small_model_tf(self):
zero_shot_classifier = pipeline(
"zero-shot-classification",
model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
framework="tf",
)
outputs = zero_shot_classifier(
"Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
)
self.assertEqual(
nested_simplify(outputs),
{
"sequence": "Who are you voting for in 2020?",
"labels": ["science", "public health", "politics"],
"scores": [0.333, 0.333, 0.333],
},
)
@slow
@require_torch
def test_large_model_pt(self):
zero_shot_classifier = pipeline("zero-shot-classification", model="roberta-large-mnli", framework="pt")
outputs = zero_shot_classifier(
"Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
)
self.assertEqual(
nested_simplify(outputs),
{ {
"sequences": "Who are you voting for in 2020?", "sequence": "Who are you voting for in 2020?",
"candidate_labels": "politics", "labels": ["politics", "public health", "science"],
"hypothesis_template": "This text is about {}", "scores": [0.976, 0.015, 0.009],
}, },
] )
valid_multi_input = { outputs = zero_shot_classifier(
"sequences": ["Who are you voting for in 2020?", "What is the capital of Spain?"], "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.",
"candidate_labels": "politics", candidate_labels=["machine learning", "statistics", "translation", "vision"],
} multi_label=True,
invalid_inputs = [ )
{"sequences": None, "candidate_labels": "politics"}, self.assertEqual(
{"sequences": "", "candidate_labels": "politics"}, nested_simplify(outputs),
{"sequences": "Who are you voting for in 2020?", "candidate_labels": None},
{"sequences": "Who are you voting for in 2020?", "candidate_labels": ""},
{ {
"sequences": "Who are you voting for in 2020?", "sequence": "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.",
"candidate_labels": "politics", "labels": ["translation", "machine learning", "vision", "statistics"],
"hypothesis_template": None, "scores": [0.817, 0.713, 0.018, 0.018],
}, },
)
@slow
@require_tf
def test_large_model_tf(self):
zero_shot_classifier = pipeline("zero-shot-classification", model="roberta-large-mnli", framework="tf")
outputs = zero_shot_classifier(
"Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
)
self.assertEqual(
nested_simplify(outputs),
{ {
"sequences": "Who are you voting for in 2020?", "sequence": "Who are you voting for in 2020?",
"candidate_labels": "politics", "labels": ["politics", "public health", "science"],
"hypothesis_template": "", "scores": [0.976, 0.015, 0.009],
}, },
)
outputs = zero_shot_classifier(
"The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.",
candidate_labels=["machine learning", "statistics", "translation", "vision"],
multi_label=True,
)
self.assertEqual(
nested_simplify(outputs),
{ {
"sequences": "Who are you voting for in 2020?", "sequence": "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.",
"candidate_labels": "politics", "labels": ["translation", "machine learning", "vision", "statistics"],
"hypothesis_template": "Template without formatting syntax.", "scores": [0.817, 0.713, 0.018, 0.018],
}, },
] )
self.assertIsNotNone(zero_shot_classifier)
self._test_entailment_id(zero_shot_classifier)
for mono_input in valid_mono_inputs:
mono_result = zero_shot_classifier(**mono_input)
self.assertIsInstance(mono_result, dict)
if len(mono_result["labels"]) > 1:
self._test_scores_sum_to_one(mono_result)
for key in output_keys:
self.assertIn(key, mono_result)
multi_result = zero_shot_classifier(**valid_multi_input)
self.assertIsInstance(multi_result, list)
self.assertIsInstance(multi_result[0], dict)
self.assertEqual(len(multi_result), len(valid_multi_input["sequences"]))
for result in multi_result:
for key in output_keys:
self.assertIn(key, result)
if len(result["labels"]) > 1:
self._test_scores_sum_to_one(result)
for bad_input in invalid_inputs:
self.assertRaises(Exception, zero_shot_classifier, **bad_input)
if zero_shot_classifier.model.name_or_path in self.large_models:
# We also check the outputs for the large models
inputs = [
{
"sequences": "Who are you voting for in 2020?",
"candidate_labels": ["politics", "public health", "science"],
},
{
"sequences": "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.",
"candidate_labels": ["machine learning", "statistics", "translation", "vision"],
"multi_label": True,
},
]
expected_outputs = [
{
"sequence": "Who are you voting for in 2020?",
"labels": ["politics", "public health", "science"],
"scores": [0.975, 0.015, 0.008],
},
{
"sequence": "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.",
"labels": ["translation", "machine learning", "vision", "statistics"],
"scores": [0.817, 0.712, 0.018, 0.017],
},
]
for input, expected_output in zip(inputs, expected_outputs):
output = zero_shot_classifier(**input)
for key in output:
if key == "scores":
for output_score, expected_score in zip(output[key], expected_output[key]):
self.assertAlmostEqual(output_score, expected_score, places=2)
else:
self.assertEqual(output[key], expected_output[key])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment