"doc/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "ee9278b22da7cfdc21c7af7d80da5071b7281490"
Unverified Commit 84caa233 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fix the behaviour of DefaultArgumentHandler (removing it). (#8180)

* Some work to fix the behaviour of DefaultArgumentHandler by removing it.

* Fixing specific pipelines argument checking.
parent 00cc2d1d
...@@ -23,9 +23,8 @@ import uuid ...@@ -23,9 +23,8 @@ import uuid
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from itertools import chain
from os.path import abspath, exists from os.path import abspath, exists
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from uuid import UUID from uuid import UUID
import numpy as np import numpy as np
...@@ -185,57 +184,6 @@ class ArgumentHandler(ABC): ...@@ -185,57 +184,6 @@ class ArgumentHandler(ABC):
raise NotImplementedError() raise NotImplementedError()
class DefaultArgumentHandler(ArgumentHandler):
    """
    Default argument parser handling parameters for each :class:`~transformers.pipelines.Pipeline`.

    Whatever the caller passes (one string, one list, several positional values, or keyword arguments) is
    normalized into a single flat list of inputs.
    """

    @staticmethod
    def handle_kwargs(kwargs: Dict) -> List:
        """
        Normalize keyword arguments.

        The keys are ignored; the values are processed exactly like positional arguments, in the order the
        dictionary yields them.
        """
        # The original special-cased ``len(kwargs) == 1`` but both branches built the same list.
        return DefaultArgumentHandler.handle_args(list(kwargs.values()))

    @staticmethod
    def handle_args(args: Sequence[Any]) -> List[str]:
        """
        Normalize positional arguments into a flat list.

        Args:
            args: The positional values received by the handler.

        Returns:
            - ``[]`` when ``args`` is empty;
            - for a single argument: the argument itself when it is a ``list``, otherwise a one-element list
              wrapping it;
            - for several arguments: strings are kept whole and any other iterable is expanded in place.

        Raises:
            ValueError: If one of several arguments is neither a string nor an iterable.
        """
        if len(args) == 0:
            return []

        # Only one argument, let's do case by case
        if len(args) == 1:
            only = args[0]
            # A list is handed back as-is (same object, not a copy); anything else is wrapped.
            return only if isinstance(only, list) else [only]

        # Multiple arguments (x1, x2, ...)
        flattened = []
        for arg in args:
            if isinstance(arg, str):
                # Strings are iterable too: keep them whole rather than exploding them into characters.
                flattened.append(arg)
            elif isinstance(arg, Iterable):
                flattened.extend(arg)
            else:
                # Check each argument, not the container, so this error is actually reachable.
                raise ValueError(
                    "Invalid input type {}. Pipeline supports Union[str, Iterable[str]]".format(type(arg))
                )
        return flattened

    def __call__(self, *args, **kwargs):
        """
        Normalize the call arguments into a flat list of inputs.

        Raises:
            ValueError: If both positional and keyword arguments are supplied.
        """
        if args and kwargs:
            raise ValueError("Pipeline cannot handle mixed args and kwargs")

        if kwargs:
            return DefaultArgumentHandler.handle_kwargs(kwargs)
        return DefaultArgumentHandler.handle_args(args)
class PipelineDataFormat: class PipelineDataFormat:
""" """
Base class for all the pipeline supported data format both for reading and writing. Supported data formats Base class for all the pipeline supported data format both for reading and writing. Supported data formats
...@@ -574,7 +522,6 @@ class Pipeline(_ScikitCompat): ...@@ -574,7 +522,6 @@ class Pipeline(_ScikitCompat):
self.framework = framework self.framework = framework
self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else "cuda:{}".format(device)) self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else "cuda:{}".format(device))
self.binary_output = binary_output self.binary_output = binary_output
self._args_parser = args_parser or DefaultArgumentHandler()
# Special handling # Special handling
if self.framework == "pt" and self.device.type == "cuda": if self.framework == "pt" and self.device.type == "cuda":
...@@ -669,12 +616,11 @@ class Pipeline(_ScikitCompat): ...@@ -669,12 +616,11 @@ class Pipeline(_ScikitCompat):
f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}", f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}",
) )
def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs): def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
""" """
Parse arguments and tokenize Parse arguments and tokenize
""" """
# Parse arguments # Parse arguments
inputs = self._args_parser(*args, **kwargs)
inputs = self.tokenizer( inputs = self.tokenizer(
inputs, inputs,
add_special_tokens=add_special_tokens, add_special_tokens=add_special_tokens,
...@@ -836,7 +782,7 @@ class TextGenerationPipeline(Pipeline): ...@@ -836,7 +782,7 @@ class TextGenerationPipeline(Pipeline):
# overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs): def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
""" """
Parse arguments and tokenize Parse arguments and tokenize
""" """
...@@ -845,7 +791,6 @@ class TextGenerationPipeline(Pipeline): ...@@ -845,7 +791,6 @@ class TextGenerationPipeline(Pipeline):
tokenizer_kwargs = {"add_space_before_punct_symbol": True} tokenizer_kwargs = {"add_space_before_punct_symbol": True}
else: else:
tokenizer_kwargs = {} tokenizer_kwargs = {}
inputs = self._args_parser(*args, **kwargs)
inputs = self.tokenizer( inputs = self.tokenizer(
inputs, inputs,
add_special_tokens=add_special_tokens, add_special_tokens=add_special_tokens,
...@@ -858,7 +803,7 @@ class TextGenerationPipeline(Pipeline): ...@@ -858,7 +803,7 @@ class TextGenerationPipeline(Pipeline):
def __call__( def __call__(
self, self,
*args, text_inputs,
return_tensors=False, return_tensors=False,
return_text=True, return_text=True,
clean_up_tokenization_spaces=False, clean_up_tokenization_spaces=False,
...@@ -890,7 +835,6 @@ class TextGenerationPipeline(Pipeline): ...@@ -890,7 +835,6 @@ class TextGenerationPipeline(Pipeline):
- **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``) - **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
-- The token ids of the generated text. -- The token ids of the generated text.
""" """
text_inputs = self._args_parser(*args)
results = [] results = []
for prompt_text in text_inputs: for prompt_text in text_inputs:
...@@ -1094,7 +1038,8 @@ class ZeroShotClassificationPipeline(Pipeline): ...@@ -1094,7 +1038,8 @@ class ZeroShotClassificationPipeline(Pipeline):
""" """
def __init__(self, args_parser=ZeroShotClassificationArgumentHandler(), *args, **kwargs): def __init__(self, args_parser=ZeroShotClassificationArgumentHandler(), *args, **kwargs):
super().__init__(*args, args_parser=args_parser, **kwargs) super().__init__(*args, **kwargs)
self._args_parser = args_parser
if self.entailment_id == -1: if self.entailment_id == -1:
logger.warning( logger.warning(
"Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to " "Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to "
...@@ -1108,13 +1053,15 @@ class ZeroShotClassificationPipeline(Pipeline): ...@@ -1108,13 +1053,15 @@ class ZeroShotClassificationPipeline(Pipeline):
return ind return ind
return -1 return -1
def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs): def _parse_and_tokenize(
self, sequences, candidal_labels, hypothesis_template, padding=True, add_special_tokens=True, **kwargs
):
""" """
Parse arguments and tokenize only_first so that hypothesis (label) is not truncated Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
""" """
inputs = self._args_parser(*args, **kwargs) sequence_pairs = self._args_parser(sequences, candidal_labels, hypothesis_template)
inputs = self.tokenizer( inputs = self.tokenizer(
inputs, sequence_pairs,
add_special_tokens=add_special_tokens, add_special_tokens=add_special_tokens,
return_tensors=self.framework, return_tensors=self.framework,
padding=padding, padding=padding,
...@@ -1123,7 +1070,13 @@ class ZeroShotClassificationPipeline(Pipeline): ...@@ -1123,7 +1070,13 @@ class ZeroShotClassificationPipeline(Pipeline):
return inputs return inputs
def __call__(self, sequences, candidate_labels, hypothesis_template="This example is {}.", multi_class=False): def __call__(
self,
sequences: Union[str, List[str]],
candidate_labels,
hypothesis_template="This example is {}.",
multi_class=False,
):
""" """
Classify the sequence(s) given as inputs. See the :obj:`~transformers.ZeroShotClassificationPipeline` Classify the sequence(s) given as inputs. See the :obj:`~transformers.ZeroShotClassificationPipeline`
documentation for more information. documentation for more information.
...@@ -1154,8 +1107,11 @@ class ZeroShotClassificationPipeline(Pipeline): ...@@ -1154,8 +1107,11 @@ class ZeroShotClassificationPipeline(Pipeline):
- **labels** (:obj:`List[str]`) -- The labels sorted by order of likelihood. - **labels** (:obj:`List[str]`) -- The labels sorted by order of likelihood.
- **scores** (:obj:`List[float]`) -- The probabilities for each of the labels. - **scores** (:obj:`List[float]`) -- The probabilities for each of the labels.
""" """
if sequences and isinstance(sequences, str):
sequences = [sequences]
outputs = super().__call__(sequences, candidate_labels, hypothesis_template) outputs = super().__call__(sequences, candidate_labels, hypothesis_template)
num_sequences = 1 if isinstance(sequences, str) else len(sequences) num_sequences = len(sequences)
candidate_labels = self._args_parser._parse_labels(candidate_labels) candidate_labels = self._args_parser._parse_labels(candidate_labels)
reshaped_outputs = outputs.reshape((num_sequences, len(candidate_labels), -1)) reshaped_outputs = outputs.reshape((num_sequences, len(candidate_labels), -1))
...@@ -1425,12 +1381,12 @@ class TokenClassificationPipeline(Pipeline): ...@@ -1425,12 +1381,12 @@ class TokenClassificationPipeline(Pipeline):
self.ignore_labels = ignore_labels self.ignore_labels = ignore_labels
self.grouped_entities = grouped_entities self.grouped_entities = grouped_entities
def __call__(self, *args, **kwargs): def __call__(self, inputs: Union[str, List[str]], **kwargs):
""" """
Classify each token of the text(s) given as inputs. Classify each token of the text(s) given as inputs.
Args: Args:
args (:obj:`str` or :obj:`List[str]`): inputs (:obj:`str` or :obj:`List[str]`):
One or several texts (or one list of texts) for token classification. One or several texts (or one list of texts) for token classification.
Return: Return:
...@@ -1444,7 +1400,8 @@ class TokenClassificationPipeline(Pipeline): ...@@ -1444,7 +1400,8 @@ class TokenClassificationPipeline(Pipeline):
- **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the - **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
corresponding token in the sentence. corresponding token in the sentence.
""" """
inputs = self._args_parser(*args, **kwargs) if isinstance(inputs, str):
inputs = [inputs]
answers = [] answers = []
for sentence in inputs: for sentence in inputs:
...@@ -1659,12 +1616,12 @@ class QuestionAnsweringPipeline(Pipeline): ...@@ -1659,12 +1616,12 @@ class QuestionAnsweringPipeline(Pipeline):
tokenizer=tokenizer, tokenizer=tokenizer,
modelcard=modelcard, modelcard=modelcard,
framework=framework, framework=framework,
args_parser=QuestionAnsweringArgumentHandler(),
device=device, device=device,
task=task, task=task,
**kwargs, **kwargs,
) )
self._args_parser = QuestionAnsweringArgumentHandler()
self.check_model_type( self.check_model_type(
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING if self.framework == "tf" else MODEL_FOR_QUESTION_ANSWERING_MAPPING TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING if self.framework == "tf" else MODEL_FOR_QUESTION_ANSWERING_MAPPING
) )
...@@ -2489,12 +2446,11 @@ class ConversationalPipeline(Pipeline): ...@@ -2489,12 +2446,11 @@ class ConversationalPipeline(Pipeline):
else: else:
return output return output
def _parse_and_tokenize(self, *args, **kwargs): def _parse_and_tokenize(self, inputs, **kwargs):
""" """
Parse arguments and tokenize, adding an EOS token at the end of the user input Parse arguments and tokenize, adding an EOS token at the end of the user input
""" """
# Parse arguments # Parse arguments
inputs = self._args_parser(*args, **kwargs)
inputs = self.tokenizer(inputs, add_special_tokens=False, padding=False).get("input_ids", []) inputs = self.tokenizer(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
for input in inputs: for input in inputs:
input.append(self.tokenizer.eos_token_id) input.append(self.tokenizer.eos_token_id)
......
import unittest
from typing import List, Optional from typing import List, Optional
from transformers import is_tf_available, is_torch_available, pipeline from transformers import is_tf_available, is_torch_available, pipeline
from transformers.pipelines import DefaultArgumentHandler, Pipeline
# from transformers.pipelines import DefaultArgumentHandler, Pipeline
from transformers.pipelines import Pipeline
from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow
...@@ -200,74 +201,74 @@ class MonoInputPipelineCommonMixin: ...@@ -200,74 +201,74 @@ class MonoInputPipelineCommonMixin:
self.assertRaises(Exception, nlp, self.invalid_inputs) self.assertRaises(Exception, nlp, self.invalid_inputs)
@is_pipeline_test # @is_pipeline_test
class DefaultArgumentHandlerTestCase(unittest.TestCase): # class DefaultArgumentHandlerTestCase(unittest.TestCase):
def setUp(self) -> None: # def setUp(self) -> None:
self.handler = DefaultArgumentHandler() # self.handler = DefaultArgumentHandler()
#
def test_kwargs_x(self): # def test_kwargs_x(self):
mono_data = {"X": "This is a sample input"} # mono_data = {"X": "This is a sample input"}
mono_args = self.handler(**mono_data) # mono_args = self.handler(**mono_data)
#
self.assertTrue(isinstance(mono_args, list)) # self.assertTrue(isinstance(mono_args, list))
self.assertEqual(len(mono_args), 1) # self.assertEqual(len(mono_args), 1)
#
multi_data = {"x": ["This is a sample input", "This is a second sample input"]} # multi_data = {"x": ["This is a sample input", "This is a second sample input"]}
multi_args = self.handler(**multi_data) # multi_args = self.handler(**multi_data)
#
self.assertTrue(isinstance(multi_args, list)) # self.assertTrue(isinstance(multi_args, list))
self.assertEqual(len(multi_args), 2) # self.assertEqual(len(multi_args), 2)
#
def test_kwargs_data(self): # def test_kwargs_data(self):
mono_data = {"data": "This is a sample input"} # mono_data = {"data": "This is a sample input"}
mono_args = self.handler(**mono_data) # mono_args = self.handler(**mono_data)
#
self.assertTrue(isinstance(mono_args, list)) # self.assertTrue(isinstance(mono_args, list))
self.assertEqual(len(mono_args), 1) # self.assertEqual(len(mono_args), 1)
#
multi_data = {"data": ["This is a sample input", "This is a second sample input"]} # multi_data = {"data": ["This is a sample input", "This is a second sample input"]}
multi_args = self.handler(**multi_data) # multi_args = self.handler(**multi_data)
#
self.assertTrue(isinstance(multi_args, list)) # self.assertTrue(isinstance(multi_args, list))
self.assertEqual(len(multi_args), 2) # self.assertEqual(len(multi_args), 2)
#
def test_multi_kwargs(self): # def test_multi_kwargs(self):
mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"} # mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"}
mono_args = self.handler(**mono_data) # mono_args = self.handler(**mono_data)
#
self.assertTrue(isinstance(mono_args, list)) # self.assertTrue(isinstance(mono_args, list))
self.assertEqual(len(mono_args), 2) # self.assertEqual(len(mono_args), 2)
#
multi_data = { # multi_data = {
"data": ["This is a sample input", "This is a second sample input"], # "data": ["This is a sample input", "This is a second sample input"],
"test": ["This is a sample input 2", "This is a second sample input 2"], # "test": ["This is a sample input 2", "This is a second sample input 2"],
} # }
multi_args = self.handler(**multi_data) # multi_args = self.handler(**multi_data)
#
self.assertTrue(isinstance(multi_args, list)) # self.assertTrue(isinstance(multi_args, list))
self.assertEqual(len(multi_args), 4) # self.assertEqual(len(multi_args), 4)
#
def test_args(self): # def test_args(self):
mono_data = "This is a sample input" # mono_data = "This is a sample input"
mono_args = self.handler(mono_data) # mono_args = self.handler(mono_data)
#
self.assertTrue(isinstance(mono_args, list)) # self.assertTrue(isinstance(mono_args, list))
self.assertEqual(len(mono_args), 1) # self.assertEqual(len(mono_args), 1)
#
mono_data = ["This is a sample input"] # mono_data = ["This is a sample input"]
mono_args = self.handler(mono_data) # mono_args = self.handler(mono_data)
#
self.assertTrue(isinstance(mono_args, list)) # self.assertTrue(isinstance(mono_args, list))
self.assertEqual(len(mono_args), 1) # self.assertEqual(len(mono_args), 1)
#
multi_data = ["This is a sample input", "This is a second sample input"] # multi_data = ["This is a sample input", "This is a second sample input"]
multi_args = self.handler(multi_data) # multi_args = self.handler(multi_data)
#
self.assertTrue(isinstance(multi_args, list)) # self.assertTrue(isinstance(multi_args, list))
self.assertEqual(len(multi_args), 2) # self.assertEqual(len(multi_args), 2)
#
multi_data = ["This is a sample input", "This is a second sample input"] # multi_data = ["This is a sample input", "This is a second sample input"]
multi_args = self.handler(*multi_data) # multi_args = self.handler(*multi_data)
#
self.assertTrue(isinstance(multi_args, list)) # self.assertTrue(isinstance(multi_args, list))
self.assertEqual(len(multi_args), 2) # self.assertEqual(len(multi_args), 2)
import unittest import unittest
import pytest
from transformers import pipeline from transformers import pipeline
from transformers.testing_utils import require_tf, require_torch, slow from transformers.testing_utils import require_tf, require_torch, slow
...@@ -37,7 +39,7 @@ EXPECTED_FILL_MASK_TARGET_RESULT = [ ...@@ -37,7 +39,7 @@ EXPECTED_FILL_MASK_TARGET_RESULT = [
class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
pipeline_task = "fill-mask" pipeline_task = "fill-mask"
pipeline_loading_kwargs = {"topk": 2} pipeline_loading_kwargs = {"top_k": 2}
small_models = ["sshleifer/tiny-distilroberta-base"] # Models tested without the @slow decorator small_models = ["sshleifer/tiny-distilroberta-base"] # Models tested without the @slow decorator
large_models = ["distilroberta-base"] # Models tested with the @slow decorator large_models = ["distilroberta-base"] # Models tested with the @slow decorator
mandatory_keys = {"sequence", "score", "token"} mandatory_keys = {"sequence", "score", "token"}
...@@ -51,6 +53,28 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): ...@@ -51,6 +53,28 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
] ]
expected_check_keys = ["sequence"] expected_check_keys = ["sequence"]
@require_torch
def test_torch_topk_deprecation(self):
    # Per the original note, `topk` was previously only handled when the pipeline
    # is built (not at the call site), so the deprecation is checked at construction.
    model_name = self.small_models[0]
    expect_future_warning = pytest.warns(FutureWarning, match=r".*use `top_k`.*")
    with expect_future_warning:
        pipeline(task="fill-mask", model=model_name, topk=1)
@require_torch
def test_torch_fill_mask(self):
    # One positional input, called with different sets of extra keyword arguments;
    # every call is expected to return a list.
    fill_masker = pipeline(task="fill-mask", model=self.small_models[0])
    masked_sentence = "My name is <mask>"

    keyword_variants = (
        {},  # plain positional call
        {"targets": [" Patrick", " Clara"]},  # This passes
        {"something": False},  # This used to fail with `cannot mix args and kwargs`
    )
    for extra_kwargs in keyword_variants:
        outputs = fill_masker(masked_sentence, **extra_kwargs)
        self.assertIsInstance(outputs, list)
@require_torch @require_torch
def test_torch_fill_mask_with_targets(self): def test_torch_fill_mask_with_targets(self):
valid_inputs = ["My name is <mask>"] valid_inputs = ["My name is <mask>"]
...@@ -94,7 +118,7 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): ...@@ -94,7 +118,7 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
model=model_name, model=model_name,
tokenizer=model_name, tokenizer=model_name,
framework="pt", framework="pt",
topk=2, top_k=2,
) )
mono_result = nlp(valid_inputs[0], targets=valid_targets) mono_result = nlp(valid_inputs[0], targets=valid_targets)
......
...@@ -17,7 +17,7 @@ class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unitte ...@@ -17,7 +17,7 @@ class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unitte
sum = 0.0 sum = 0.0
for score in result["scores"]: for score in result["scores"]:
sum += score sum += score
self.assertAlmostEqual(sum, 1.0) self.assertAlmostEqual(sum, 1.0, places=5)
def _test_entailment_id(self, nlp: Pipeline): def _test_entailment_id(self, nlp: Pipeline):
config = nlp.model.config config = nlp.model.config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment