Unverified Commit 84caa233 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fix the behaviour of DefaultArgumentHandler (removing it). (#8180)

* Some work to fix the behaviour of DefaultArgumentHandler by removing it.

* Fixing specific pipelines argument checking.
parent 00cc2d1d
......@@ -23,9 +23,8 @@ import uuid
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from itertools import chain
from os.path import abspath, exists
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from uuid import UUID
import numpy as np
......@@ -185,57 +184,6 @@ class ArgumentHandler(ABC):
raise NotImplementedError()
class DefaultArgumentHandler(ArgumentHandler):
"""
Default argument parser handling parameters for each :class:`~transformers.pipelines.Pipeline`.
"""
@staticmethod
def handle_kwargs(kwargs: Dict) -> List:
if len(kwargs) == 1:
output = list(kwargs.values())
else:
output = list(chain(kwargs.values()))
return DefaultArgumentHandler.handle_args(output)
@staticmethod
def handle_args(args: Sequence[Any]) -> List[str]:
# Only one argument, let's do case by case
if len(args) == 1:
if isinstance(args[0], str):
return [args[0]]
elif not isinstance(args[0], list):
return list(args)
else:
return args[0]
# Multiple arguments (x1, x2, ...)
elif len(args) > 1:
if all([isinstance(arg, str) for arg in args]):
return list(args)
# If not instance of list, then it should instance of iterable
elif isinstance(args, Iterable):
return list(chain.from_iterable(chain(args)))
else:
raise ValueError(
"Invalid input type {}. Pipeline supports Union[str, Iterable[str]]".format(type(args))
)
else:
return []
def __call__(self, *args, **kwargs):
if len(kwargs) > 0 and len(args) > 0:
raise ValueError("Pipeline cannot handle mixed args and kwargs")
if len(kwargs) > 0:
return DefaultArgumentHandler.handle_kwargs(kwargs)
else:
return DefaultArgumentHandler.handle_args(args)
class PipelineDataFormat:
"""
Base class for all the pipeline supported data format both for reading and writing. Supported data formats
......@@ -574,7 +522,6 @@ class Pipeline(_ScikitCompat):
self.framework = framework
self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else "cuda:{}".format(device))
self.binary_output = binary_output
self._args_parser = args_parser or DefaultArgumentHandler()
# Special handling
if self.framework == "pt" and self.device.type == "cuda":
......@@ -669,12 +616,11 @@ class Pipeline(_ScikitCompat):
f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}",
)
def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
"""
Parse arguments and tokenize
"""
# Parse arguments
inputs = self._args_parser(*args, **kwargs)
inputs = self.tokenizer(
inputs,
add_special_tokens=add_special_tokens,
......@@ -836,7 +782,7 @@ class TextGenerationPipeline(Pipeline):
# overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
"""
Parse arguments and tokenize
"""
......@@ -845,7 +791,6 @@ class TextGenerationPipeline(Pipeline):
tokenizer_kwargs = {"add_space_before_punct_symbol": True}
else:
tokenizer_kwargs = {}
inputs = self._args_parser(*args, **kwargs)
inputs = self.tokenizer(
inputs,
add_special_tokens=add_special_tokens,
......@@ -858,7 +803,7 @@ class TextGenerationPipeline(Pipeline):
def __call__(
self,
*args,
text_inputs,
return_tensors=False,
return_text=True,
clean_up_tokenization_spaces=False,
......@@ -890,7 +835,6 @@ class TextGenerationPipeline(Pipeline):
- **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
-- The token ids of the generated text.
"""
text_inputs = self._args_parser(*args)
results = []
for prompt_text in text_inputs:
......@@ -1094,7 +1038,8 @@ class ZeroShotClassificationPipeline(Pipeline):
"""
def __init__(self, args_parser=ZeroShotClassificationArgumentHandler(), *args, **kwargs):
super().__init__(*args, args_parser=args_parser, **kwargs)
super().__init__(*args, **kwargs)
self._args_parser = args_parser
if self.entailment_id == -1:
logger.warning(
"Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to "
......@@ -1108,13 +1053,15 @@ class ZeroShotClassificationPipeline(Pipeline):
return ind
return -1
def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
def _parse_and_tokenize(
self, sequences, candidal_labels, hypothesis_template, padding=True, add_special_tokens=True, **kwargs
):
"""
Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
"""
inputs = self._args_parser(*args, **kwargs)
sequence_pairs = self._args_parser(sequences, candidal_labels, hypothesis_template)
inputs = self.tokenizer(
inputs,
sequence_pairs,
add_special_tokens=add_special_tokens,
return_tensors=self.framework,
padding=padding,
......@@ -1123,7 +1070,13 @@ class ZeroShotClassificationPipeline(Pipeline):
return inputs
def __call__(self, sequences, candidate_labels, hypothesis_template="This example is {}.", multi_class=False):
def __call__(
self,
sequences: Union[str, List[str]],
candidate_labels,
hypothesis_template="This example is {}.",
multi_class=False,
):
"""
Classify the sequence(s) given as inputs. See the :obj:`~transformers.ZeroShotClassificationPipeline`
documentation for more information.
......@@ -1154,8 +1107,11 @@ class ZeroShotClassificationPipeline(Pipeline):
- **labels** (:obj:`List[str]`) -- The labels sorted by order of likelihood.
- **scores** (:obj:`List[float]`) -- The probabilities for each of the labels.
"""
if sequences and isinstance(sequences, str):
sequences = [sequences]
outputs = super().__call__(sequences, candidate_labels, hypothesis_template)
num_sequences = 1 if isinstance(sequences, str) else len(sequences)
num_sequences = len(sequences)
candidate_labels = self._args_parser._parse_labels(candidate_labels)
reshaped_outputs = outputs.reshape((num_sequences, len(candidate_labels), -1))
......@@ -1425,12 +1381,12 @@ class TokenClassificationPipeline(Pipeline):
self.ignore_labels = ignore_labels
self.grouped_entities = grouped_entities
def __call__(self, *args, **kwargs):
def __call__(self, inputs: Union[str, List[str]], **kwargs):
"""
Classify each token of the text(s) given as inputs.
Args:
args (:obj:`str` or :obj:`List[str]`):
inputs (:obj:`str` or :obj:`List[str]`):
One or several texts (or one list of texts) for token classification.
Return:
......@@ -1444,7 +1400,8 @@ class TokenClassificationPipeline(Pipeline):
- **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
corresponding token in the sentence.
"""
inputs = self._args_parser(*args, **kwargs)
if isinstance(inputs, str):
inputs = [inputs]
answers = []
for sentence in inputs:
......@@ -1659,12 +1616,12 @@ class QuestionAnsweringPipeline(Pipeline):
tokenizer=tokenizer,
modelcard=modelcard,
framework=framework,
args_parser=QuestionAnsweringArgumentHandler(),
device=device,
task=task,
**kwargs,
)
self._args_parser = QuestionAnsweringArgumentHandler()
self.check_model_type(
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING if self.framework == "tf" else MODEL_FOR_QUESTION_ANSWERING_MAPPING
)
......@@ -2489,12 +2446,11 @@ class ConversationalPipeline(Pipeline):
else:
return output
def _parse_and_tokenize(self, *args, **kwargs):
def _parse_and_tokenize(self, inputs, **kwargs):
"""
Parse arguments and tokenize, adding an EOS token at the end of the user input
"""
# Parse arguments
inputs = self._args_parser(*args, **kwargs)
inputs = self.tokenizer(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
for input in inputs:
input.append(self.tokenizer.eos_token_id)
......
import unittest
from typing import List, Optional
from transformers import is_tf_available, is_torch_available, pipeline
from transformers.pipelines import DefaultArgumentHandler, Pipeline
# from transformers.pipelines import DefaultArgumentHandler, Pipeline
from transformers.pipelines import Pipeline
from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow
......@@ -200,74 +201,74 @@ class MonoInputPipelineCommonMixin:
self.assertRaises(Exception, nlp, self.invalid_inputs)
@is_pipeline_test
class DefaultArgumentHandlerTestCase(unittest.TestCase):
def setUp(self) -> None:
self.handler = DefaultArgumentHandler()
def test_kwargs_x(self):
mono_data = {"X": "This is a sample input"}
mono_args = self.handler(**mono_data)
self.assertTrue(isinstance(mono_args, list))
self.assertEqual(len(mono_args), 1)
multi_data = {"x": ["This is a sample input", "This is a second sample input"]}
multi_args = self.handler(**multi_data)
self.assertTrue(isinstance(multi_args, list))
self.assertEqual(len(multi_args), 2)
def test_kwargs_data(self):
mono_data = {"data": "This is a sample input"}
mono_args = self.handler(**mono_data)
self.assertTrue(isinstance(mono_args, list))
self.assertEqual(len(mono_args), 1)
multi_data = {"data": ["This is a sample input", "This is a second sample input"]}
multi_args = self.handler(**multi_data)
self.assertTrue(isinstance(multi_args, list))
self.assertEqual(len(multi_args), 2)
def test_multi_kwargs(self):
mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"}
mono_args = self.handler(**mono_data)
self.assertTrue(isinstance(mono_args, list))
self.assertEqual(len(mono_args), 2)
multi_data = {
"data": ["This is a sample input", "This is a second sample input"],
"test": ["This is a sample input 2", "This is a second sample input 2"],
}
multi_args = self.handler(**multi_data)
self.assertTrue(isinstance(multi_args, list))
self.assertEqual(len(multi_args), 4)
def test_args(self):
mono_data = "This is a sample input"
mono_args = self.handler(mono_data)
self.assertTrue(isinstance(mono_args, list))
self.assertEqual(len(mono_args), 1)
mono_data = ["This is a sample input"]
mono_args = self.handler(mono_data)
self.assertTrue(isinstance(mono_args, list))
self.assertEqual(len(mono_args), 1)
multi_data = ["This is a sample input", "This is a second sample input"]
multi_args = self.handler(multi_data)
self.assertTrue(isinstance(multi_args, list))
self.assertEqual(len(multi_args), 2)
multi_data = ["This is a sample input", "This is a second sample input"]
multi_args = self.handler(*multi_data)
self.assertTrue(isinstance(multi_args, list))
self.assertEqual(len(multi_args), 2)
# @is_pipeline_test
# class DefaultArgumentHandlerTestCase(unittest.TestCase):
# def setUp(self) -> None:
# self.handler = DefaultArgumentHandler()
#
# def test_kwargs_x(self):
# mono_data = {"X": "This is a sample input"}
# mono_args = self.handler(**mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 1)
#
# multi_data = {"x": ["This is a sample input", "This is a second sample input"]}
# multi_args = self.handler(**multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 2)
#
# def test_kwargs_data(self):
# mono_data = {"data": "This is a sample input"}
# mono_args = self.handler(**mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 1)
#
# multi_data = {"data": ["This is a sample input", "This is a second sample input"]}
# multi_args = self.handler(**multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 2)
#
# def test_multi_kwargs(self):
# mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"}
# mono_args = self.handler(**mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 2)
#
# multi_data = {
# "data": ["This is a sample input", "This is a second sample input"],
# "test": ["This is a sample input 2", "This is a second sample input 2"],
# }
# multi_args = self.handler(**multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 4)
#
# def test_args(self):
# mono_data = "This is a sample input"
# mono_args = self.handler(mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 1)
#
# mono_data = ["This is a sample input"]
# mono_args = self.handler(mono_data)
#
# self.assertTrue(isinstance(mono_args, list))
# self.assertEqual(len(mono_args), 1)
#
# multi_data = ["This is a sample input", "This is a second sample input"]
# multi_args = self.handler(multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 2)
#
# multi_data = ["This is a sample input", "This is a second sample input"]
# multi_args = self.handler(*multi_data)
#
# self.assertTrue(isinstance(multi_args, list))
# self.assertEqual(len(multi_args), 2)
import unittest
import pytest
from transformers import pipeline
from transformers.testing_utils import require_tf, require_torch, slow
......@@ -37,7 +39,7 @@ EXPECTED_FILL_MASK_TARGET_RESULT = [
class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
pipeline_task = "fill-mask"
pipeline_loading_kwargs = {"topk": 2}
pipeline_loading_kwargs = {"top_k": 2}
small_models = ["sshleifer/tiny-distilroberta-base"] # Models tested without the @slow decorator
large_models = ["distilroberta-base"] # Models tested with the @slow decorator
mandatory_keys = {"sequence", "score", "token"}
......@@ -51,6 +53,28 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
]
expected_check_keys = ["sequence"]
@require_torch
def test_torch_topk_deprecation(self):
# At pipeline initialization only it was not enabled at pipeline
# call site before
with pytest.warns(FutureWarning, match=r".*use `top_k`.*"):
pipeline(task="fill-mask", model=self.small_models[0], topk=1)
@require_torch
def test_torch_fill_mask(self):
valid_inputs = "My name is <mask>"
nlp = pipeline(task="fill-mask", model=self.small_models[0])
outputs = nlp(valid_inputs)
self.assertIsInstance(outputs, list)
# This passes
outputs = nlp(valid_inputs, targets=[" Patrick", " Clara"])
self.assertIsInstance(outputs, list)
# This used to fail with `cannot mix args and kwargs`
outputs = nlp(valid_inputs, something=False)
self.assertIsInstance(outputs, list)
@require_torch
def test_torch_fill_mask_with_targets(self):
valid_inputs = ["My name is <mask>"]
......@@ -94,7 +118,7 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
model=model_name,
tokenizer=model_name,
framework="pt",
topk=2,
top_k=2,
)
mono_result = nlp(valid_inputs[0], targets=valid_targets)
......
......@@ -17,7 +17,7 @@ class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unitte
sum = 0.0
for score in result["scores"]:
sum += score
self.assertAlmostEqual(sum, 1.0)
self.assertAlmostEqual(sum, 1.0, places=5)
def _test_entailment_id(self, nlp: Pipeline):
config = nlp.model.config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment