Unverified Commit b89a964d authored by Nicolas Patry, committed by GitHub

Moving `zero-shot-classification` pipeline to new testing. (#13299)

* Moving `zero-shot-classification` pipeline to new testing.

* Cleaning up old mixins.

* Fixing tests
`sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english` is
corrupted in PT.

* Adding warning.
parent cc27ac1a
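For context, the tests being migrated exercise the standard zero-shot-classification pipeline API. A minimal sketch of a typical invocation follows; the model name and inputs are illustrative and not taken from this PR.

from transformers import pipeline

# Load a zero-shot classifier; facebook/bart-large-mnli is a commonly used NLI backbone.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
result = classifier(
    "The new GPU drastically reduces training time.",
    candidate_labels=["hardware", "cooking", "politics"],
)
# For a single sequence, result is a dict with "sequence", "labels" (sorted by score) and "scores".
print(result["labels"][0], result["scores"][0])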
@@ -2,12 +2,15 @@ from typing import List, Union
import numpy as np
from ..file_utils import add_end_docstrings
from ..file_utils import add_end_docstrings, is_torch_available
from ..tokenization_utils import TruncationStrategy
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
@@ -85,23 +88,84 @@ class ZeroShotClassificationPipeline(Pipeline):
hypothesis_template,
padding=True,
add_special_tokens=True,
truncation=TruncationStrategy.ONLY_FIRST,
truncation=TruncationStrategy.DO_NOT_TRUNCATE,
**kwargs
):
"""
Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
"""
sequence_pairs = self._args_parser(sequences, candidate_labels, hypothesis_template)
inputs = self.tokenizer(
sequence_pairs,
add_special_tokens=add_special_tokens,
return_tensors=self.framework,
padding=padding,
truncation=truncation,
)
return_tensors = self.framework
if getattr(self.tokenizer, "pad_token", None) is None:
# XXX some tokenizers do not have a padding token, we use simple lists
# and no padding then
logger.warning("The tokenizer {self.tokenizer} does not have a pad token, we're not running it as a batch")
padding = False
inputs = []
for sequence_pair in sequence_pairs:
model_input = self.tokenizer(
text=sequence_pair[0],
text_pair=sequence_pair[1],
add_special_tokens=add_special_tokens,
return_tensors=return_tensors,
padding=padding,
truncation=truncation,
)
inputs.append(model_input)
else:
inputs = self.tokenizer(
sequence_pairs,
add_special_tokens=add_special_tokens,
return_tensors=return_tensors,
padding=padding,
truncation=truncation,
)
return inputs
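The per-pair fallback above exists because a tokenizer without a pad token cannot pad a batch of unequal-length pairs into a single tensor. A hedged illustration of that failure mode, using gpt2 only as an example of a tokenizer that ships without a pad token (this snippet is not part of the diff):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # has no pad_token by default
pairs = [("I love this movie", "This example is positive."),
         ("The service was terrible", "This example is positive.")]

try:
    # A batched call with padding fails when no pad token is defined.
    tokenizer([p[0] for p in pairs], [p[1] for p in pairs], padding=True, return_tensors="pt")
except ValueError as err:
    print("batched call fails:", err)

# Tokenizing one pair at a time, without padding, still works.
encodings = [tokenizer(text=s, text_pair=h, return_tensors="pt") for s, h in pairs]
print([enc["input_ids"].shape for enc in encodings])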
def _forward(self, inputs, return_tensors=False):
"""
Internal framework specific forward dispatching
Args:
inputs: dict holding all the keyword arguments required by the model forward method.
return_tensors: Whether to return native framework (pt/tf) tensors rather than a numpy array
Returns:
Numpy array
"""
# Encode for forward
with self.device_placement():
if self.framework == "tf":
if isinstance(inputs, list):
predictions = []
for input_ in inputs:
prediction = self.model(input_.data, training=False)[0]
predictions.append(prediction)
else:
predictions = self.model(inputs.data, training=False)[0]
else:
with torch.no_grad():
if isinstance(inputs, list):
predictions = []
for input_ in inputs:
model_input = self.ensure_tensor_on_device(**input_)
prediction = self.model(**model_input)[0].cpu()
predictions.append(prediction)
else:
inputs = self.ensure_tensor_on_device(**inputs)
predictions = self.model(**inputs)[0].cpu()
if return_tensors:
return predictions
else:
if isinstance(predictions, list):
predictions = np.array([p.numpy() for p in predictions])
else:
predictions = predictions.numpy()
return predictions
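`ensure_tensor_on_device`, called in the PyTorch branch above, is the pipeline helper that moves tokenizer output onto the model's device before the forward pass. A rough, hypothetical equivalent with a simplified name and signature (not the library implementation):

import torch

def move_to_device(inputs: dict, device: torch.device) -> dict:
    # Move every tensor in a tokenizer output dict to the target device,
    # leaving non-tensor entries untouched.
    return {
        name: value.to(device) if isinstance(value, torch.Tensor) else value
        for name, value in inputs.items()
    }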
def __call__(
self,
sequences: Union[str, List[str]],
@@ -151,6 +215,12 @@ class ZeroShotClassificationPipeline(Pipeline):
sequences = [sequences]
outputs = super().__call__(sequences, candidate_labels, hypothesis_template)
if isinstance(outputs, list):
# XXX: Some tokenizers cannot handle batching because they don't
# have a pad_token, so outputs will be a list. However, because each output
# contains only the logits (sequence_length is no longer present), we
# can recreate a tensor out of outputs.
outputs = np.array(outputs)
num_sequences = len(sequences)
candidate_labels = self._args_parser._parse_labels(candidate_labels)
reshaped_outputs = outputs.reshape((num_sequences, len(candidate_labels), -1))
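The reshape at the end of `__call__` recovers the per-sequence, per-label structure from the flat list of NLI logits. A small numpy sketch with illustrative shapes (2 sequences, 3 candidate labels, 3 NLI classes):

import numpy as np

num_sequences, num_labels, num_classes = 2, 3, 3  # e.g. contradiction/neutral/entailment
flat_logits = np.random.randn(num_sequences * num_labels, num_classes)  # one row per (sequence, label) pair
reshaped = flat_logits.reshape((num_sequences, num_labels, -1))
print(reshaped.shape)  # (2, 3, 3)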
@@ -17,21 +17,9 @@ import logging
import string
from abc import abstractmethod
from functools import lru_cache
from typing import List, Optional
from unittest import mock, skipIf
from transformers import (
FEATURE_EXTRACTOR_MAPPING,
TOKENIZER_MAPPING,
AutoFeatureExtractor,
AutoTokenizer,
is_tf_available,
is_torch_available,
pipeline,
)
from transformers.file_utils import to_py_obj
from transformers.pipelines import Pipeline
from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow
from unittest import skipIf
from transformers import FEATURE_EXTRACTOR_MAPPING, TOKENIZER_MAPPING, AutoFeatureExtractor, AutoTokenizer
logger = logging.getLogger(__name__)
@@ -189,228 +177,3 @@ class PipelineTestCaseMeta(type):
dct["test_small_model_tf"] = dct.get("test_small_model_tf", inner)
return type.__new__(mcs, name, bases, dct)
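`PipelineTestCaseMeta`, shown just above, injects default test methods into a test class when that class does not define its own. A simplified, hypothetical sketch of the same metaclass pattern (class and test names are illustrative):

import unittest


class DefaultTestMeta(type):
    def __new__(mcs, name, bases, dct):
        def inner(self):
            self.skipTest("no small TF model configured for this pipeline")

        # Keep an explicitly defined test; otherwise fall back to the default.
        dct["test_small_model_tf"] = dct.get("test_small_model_tf", inner)
        return type.__new__(mcs, name, bases, dct)


class SomePipelineTests(unittest.TestCase, metaclass=DefaultTestMeta):
    pass  # inherits the injected test_small_model_tf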
VALID_INPUTS = ["A simple string", ["list of strings"]]
@is_pipeline_test
class CustomInputPipelineCommonMixin:
pipeline_task = None
pipeline_loading_kwargs = {} # Additional kwargs to load the pipeline with
pipeline_running_kwargs = {} # Additional kwargs to run the pipeline with
small_models = [] # Models tested without the @slow decorator
large_models = [] # Models tested with the @slow decorator
valid_inputs = VALID_INPUTS # Some inputs which are valid to compare fast and slow tokenizers
def setUp(self) -> None:
if not is_tf_available() and not is_torch_available():
return # Currently no JAX pipelines
# Download needed checkpoints
models = self.small_models
if _run_slow_tests:
models = models + self.large_models
for model_name in models:
if is_torch_available():
pipeline(
self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="pt",
**self.pipeline_loading_kwargs,
)
if is_tf_available():
pipeline(
self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="tf",
**self.pipeline_loading_kwargs,
)
@require_torch
@slow
def test_pt_defaults(self):
pipeline(self.pipeline_task, framework="pt", **self.pipeline_loading_kwargs)
@require_tf
@slow
def test_tf_defaults(self):
pipeline(self.pipeline_task, framework="tf", **self.pipeline_loading_kwargs)
@require_torch
def test_torch_small(self):
for model_name in self.small_models:
pipe_small = pipeline(
task=self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="pt",
**self.pipeline_loading_kwargs,
)
self._test_pipeline(pipe_small)
@require_tf
def test_tf_small(self):
for model_name in self.small_models:
pipe_small = pipeline(
task=self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="tf",
**self.pipeline_loading_kwargs,
)
self._test_pipeline(pipe_small)
@require_torch
@slow
def test_torch_large(self):
for model_name in self.large_models:
pipe_large = pipeline(
task=self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="pt",
**self.pipeline_loading_kwargs,
)
self._test_pipeline(pipe_large)
@require_tf
@slow
def test_tf_large(self):
for model_name in self.large_models:
pipe_large = pipeline(
task=self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="tf",
**self.pipeline_loading_kwargs,
)
self._test_pipeline(pipe_large)
def _test_pipeline(self, pipe: Pipeline):
raise NotImplementedError
@require_torch
def test_compare_slow_fast_torch(self):
for model_name in self.small_models:
pipe_slow = pipeline(
task=self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="pt",
use_fast=False,
**self.pipeline_loading_kwargs,
)
pipe_fast = pipeline(
task=self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="pt",
use_fast=True,
**self.pipeline_loading_kwargs,
)
self._compare_slow_fast_pipelines(pipe_slow, pipe_fast, method="forward")
@require_tf
def test_compare_slow_fast_tf(self):
for model_name in self.small_models:
pipe_slow = pipeline(
task=self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="tf",
use_fast=False,
**self.pipeline_loading_kwargs,
)
pipe_fast = pipeline(
task=self.pipeline_task,
model=model_name,
tokenizer=model_name,
framework="tf",
use_fast=True,
**self.pipeline_loading_kwargs,
)
self._compare_slow_fast_pipelines(pipe_slow, pipe_fast, method="call")
def _compare_slow_fast_pipelines(self, pipe_slow: Pipeline, pipe_fast: Pipeline, method: str):
"""We check that the inputs to the models forward passes are identical for
slow and fast tokenizers.
"""
with mock.patch.object(
pipe_slow.model, method, wraps=getattr(pipe_slow.model, method)
) as mock_slow, mock.patch.object(
pipe_fast.model, method, wraps=getattr(pipe_fast.model, method)
) as mock_fast:
for inputs in self.valid_inputs:
if isinstance(inputs, dict):
inputs.update(self.pipeline_running_kwargs)
_ = pipe_slow(**inputs)
_ = pipe_fast(**inputs)
else:
_ = pipe_slow(inputs, **self.pipeline_running_kwargs)
_ = pipe_fast(inputs, **self.pipeline_running_kwargs)
mock_slow.assert_called()
mock_fast.assert_called()
self.assertEqual(len(mock_slow.call_args_list), len(mock_fast.call_args_list))
for mock_slow_call_args, mock_fast_call_args in zip(
mock_slow.call_args_list, mock_fast.call_args_list
):
slow_call_args, slow_call_kwargs = mock_slow_call_args
fast_call_args, fast_call_kwargs = mock_fast_call_args
slow_call_args, slow_call_kwargs = to_py_obj(slow_call_args), to_py_obj(slow_call_kwargs)
fast_call_args, fast_call_kwargs = to_py_obj(fast_call_args), to_py_obj(fast_call_kwargs)
self.assertEqual(slow_call_args, fast_call_args)
self.assertDictEqual(slow_call_kwargs, fast_call_kwargs)
@is_pipeline_test
class MonoInputPipelineCommonMixin(CustomInputPipelineCommonMixin):
"""A version of the CustomInputPipelineCommonMixin
with a predefined `_test_pipeline` method.
"""
mandatory_keys = {} # Keys which should be in the output
invalid_inputs = [None] # inputs which are not allowed
expected_multi_result: Optional[List] = None
expected_check_keys: Optional[List[str]] = None
def _test_pipeline(self, pipe: Pipeline):
self.assertIsNotNone(pipe)
mono_result = pipe(self.valid_inputs[0], **self.pipeline_running_kwargs)
self.assertIsInstance(mono_result, list)
self.assertIsInstance(mono_result[0], (dict, list))
if isinstance(mono_result[0], list):
mono_result = mono_result[0]
for key in self.mandatory_keys:
self.assertIn(key, mono_result[0])
multi_result = [pipe(input, **self.pipeline_running_kwargs) for input in self.valid_inputs]
self.assertIsInstance(multi_result, list)
self.assertIsInstance(multi_result[0], (dict, list))
if self.expected_multi_result is not None:
for result, expect in zip(multi_result, self.expected_multi_result):
for key in self.expected_check_keys or []:
self.assertEqual(
set([o[key] for o in result]),
set([o[key] for o in expect]),
)
if isinstance(multi_result[0], list):
multi_result = multi_result[0]
for result in multi_result:
for key in self.mandatory_keys:
self.assertIn(key, result)
self.assertRaises(Exception, pipe, self.invalid_inputs)