Unverified Commit 0a6cbea0 authored by Morgan Funtowicz, committed by GitHub

Rewritten batch support in pipelines. (#4154)



* Rewritten batch support in pipelines.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Fix imports sorting 🔧

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Set pad_to_max_length=True by default on Pipeline.

* Set pad_to_max_length=False for generation pipelines.

Most generation models don't have a padding token.

* Address @joeddav review comment: Uniformized *args.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Address @joeddav review comment: Uniformized *args (second).
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
parent 99d1a694
@@ -22,8 +22,9 @@ import pickle
 import sys
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
+from itertools import chain
 from os.path import abspath, exists
-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

 import numpy as np
@@ -96,19 +97,50 @@ class DefaultArgumentHandler(ArgumentHandler):
     Default varargs argument parser handling parameters for each Pipeline
     """

-    def __call__(self, *args, **kwargs):
-        if "X" in kwargs:
-            return kwargs["X"]
-        elif "data" in kwargs:
-            return kwargs["data"]
-        elif len(args) == 1:
-            if isinstance(args[0], list):
-                return args[0]
-            else:
-                return [args[0]]
-        elif len(args) > 1:
-            return list(args)
-        raise ValueError("Unable to infer the format of the provided data (X=, data=, ...)")
+    @staticmethod
+    def handle_kwargs(kwargs: Dict) -> List:
+        if len(kwargs) == 1:
+            output = list(kwargs.values())
+        else:
+            output = list(chain(kwargs.values()))
+
+        return DefaultArgumentHandler.handle_args(output)
+
+    @staticmethod
+    def handle_args(args: Sequence[Any]) -> List[str]:
+        # Only one argument, let's do case by case
+        if len(args) == 1:
+            if isinstance(args[0], str):
+                return [args[0]]
+            elif not isinstance(args[0], list):
+                return list(args)
+            else:
+                return args[0]
+
+        # Multiple arguments (x1, x2, ...)
+        elif len(args) > 1:
+            if all([isinstance(arg, str) for arg in args]):
+                return list(args)
+
+            # If not an instance of list, then it should be an instance of iterable
+            elif isinstance(args, Iterable):
+                return list(chain.from_iterable(chain(args)))
+            else:
+                raise ValueError(
+                    "Invalid input type {}. Pipeline supports Union[str, Iterable[str]]".format(type(args))
+                )
+        else:
+            return []
+
+    def __call__(self, *args, **kwargs):
+        if len(kwargs) > 0 and len(args) > 0:
+            raise ValueError("Pipeline cannot handle mixed args and kwargs")
+
+        if len(kwargs) > 0:
+            return DefaultArgumentHandler.handle_kwargs(kwargs)
+        else:
+            return DefaultArgumentHandler.handle_args(args)


 class PipelineDataFormat:
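For reference, the rewritten handler normalizes every supported calling convention into a flat list of strings before tokenization. A minimal sketch of the resulting behavior (illustrative only; it mirrors the unit tests added to the test suite further down):

    from transformers.pipelines import DefaultArgumentHandler

    handler = DefaultArgumentHandler()

    # A single string, a list of strings, and plain varargs all normalize to List[str]
    assert handler("one sentence") == ["one sentence"]
    assert handler(["first", "second"]) == ["first", "second"]
    assert handler("first", "second") == ["first", "second"]

    # Keyword arguments (X=, data=, or any other name) go through handle_kwargs
    assert handler(data=["first", "second"]) == ["first", "second"]

    # Mixing positional and keyword arguments raises a ValueError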
@@ -418,20 +450,20 @@ class Pipeline(_ScikitCompat):
         """
         return {name: tensor.to(self.device) for name, tensor in inputs.items()}

-    def _parse_and_tokenize(self, *texts, pad_to_max_length=False, **kwargs):
+    def _parse_and_tokenize(self, *args, pad_to_max_length=True, **kwargs):
         """
         Parse arguments and tokenize
         """
         # Parse arguments
-        inputs = self._args_parser(*texts, **kwargs)
+        inputs = self._args_parser(*args, **kwargs)
         inputs = self.tokenizer.batch_encode_plus(
             inputs, add_special_tokens=True, return_tensors=self.framework, pad_to_max_length=pad_to_max_length,
         )

         return inputs

-    def __call__(self, *texts, **kwargs):
-        inputs = self._parse_and_tokenize(*texts, **kwargs)
+    def __call__(self, *args, **kwargs):
+        inputs = self._parse_and_tokenize(*args, **kwargs)
         return self._forward(inputs)

     def _forward(self, inputs, return_tensors=False):
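With pad_to_max_length=True as the new default, inputs of different lengths can be encoded and run through the model as a single batch. A hedged usage sketch (the task and input strings are examples, not part of this diff):

    from transformers import pipeline

    nlp = pipeline("sentiment-analysis")

    # Inputs of unequal length are padded to a common length inside
    # _parse_and_tokenize, so they run through the model as one batch
    results = nlp(["This is great!", "This is a much longer and more rambling input sentence."])
    print(results)  # one {'label': ..., 'score': ...} dict per input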
@@ -550,18 +582,18 @@ class TextGenerationPipeline(Pipeline):
     with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""

     def __call__(
-        self, *texts, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
+        self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
     ):
-        text_inputs = self._args_parser(*texts)
+        text_inputs = self._args_parser(*args)

         results = []
         for prompt_text in text_inputs:
             # Manage correct placement of the tensors
             with self.device_placement():
                 if self.model.__class__.__name__ in ["XLNetLMHeadModel", "TransfoXLLMHeadModel"]:
-                    inputs = self._parse_and_tokenize(self.PADDING_TEXT + prompt_text)
+                    inputs = self._parse_and_tokenize(self.PADDING_TEXT + prompt_text, pad_to_max_length=False)
                 else:
-                    inputs = self._parse_and_tokenize(prompt_text)
+                    inputs = self._parse_and_tokenize(prompt_text, pad_to_max_length=False)

                 # set input_ids to None to allow empty prompt
                 if inputs["input_ids"].shape[-1] == 0:
@@ -825,8 +857,8 @@ class NerPipeline(Pipeline):
         self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
         self.ignore_labels = ignore_labels

-    def __call__(self, *texts, **kwargs):
-        inputs = self._args_parser(*texts, **kwargs)
+    def __call__(self, *args, **kwargs):
+        inputs = self._args_parser(*args, **kwargs)
         answers = []
         for sentence in inputs:
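The NER signature change only renames *texts to *args; batched input keeps flowing through the argument handler. For example (resolution to the default NER checkpoint is assumed):

    from transformers import pipeline

    ner = pipeline("ner")
    # Both a single sentence and a list of sentences are accepted
    print(ner("Hugging Face is based in New York City."))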
@@ -1016,7 +1048,7 @@ class QuestionAnsweringPipeline(Pipeline):
         else:
             return SquadExample(None, question, context, None, None, None)

-    def __call__(self, *texts, **kwargs):
+    def __call__(self, *args, **kwargs):
         """
         Args:
             We support multiple use-cases, the following are exclusive:
@@ -1046,7 +1078,7 @@ class QuestionAnsweringPipeline(Pipeline):
             raise ValueError("max_answer_len parameter should be >= 1 (got {})".format(kwargs["max_answer_len"]))

         # Convert inputs to features
-        examples = self._args_parser(*texts, **kwargs)
+        examples = self._args_parser(*args, **kwargs)
         features_list = [
             squad_convert_examples_to_features(
                 [example],
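Question answering keeps its dedicated argument handler, so the usual keyword form still produces the SquadExample objects consumed above. A hedged example of that call:

    from transformers import pipeline

    qa = pipeline("question-answering")
    answer = qa(
        question="Where is Hugging Face based?",
        context="Hugging Face is based in New York City.",
    )
    print(answer)  # {'score': ..., 'start': ..., 'end': ..., 'answer': ...}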
@@ -1383,11 +1415,11 @@ class TranslationPipeline(Pipeline):
     """

     def __call__(
-        self, *texts, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
+        self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
     ):
         r"""
         Args:
-            *texts: (list of strings) texts to be translated
+            *args: (list of strings) texts to be translated
             return_text: (bool, default=True) whether to add a decoded "translation_text" to each result
             return_tensors: (bool, default=False) whether to return the raw "translation_token_ids" to each result
@@ -1402,25 +1434,25 @@ class TranslationPipeline(Pipeline):
         prefix = self.model.config.prefix if self.model.config.prefix is not None else ""

-        if isinstance(texts[0], list):
+        if isinstance(args[0], list):
             assert (
                 self.tokenizer.pad_token_id is not None
             ), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
-            texts = ([prefix + text for text in texts[0]],)
+            args = ([prefix + text for text in args[0]],)
             pad_to_max_length = True

-        elif isinstance(texts[0], str):
-            texts = (prefix + texts[0],)
+        elif isinstance(args[0], str):
+            args = (prefix + args[0],)
             pad_to_max_length = False
         else:
             raise ValueError(
-                "`documents[0]`: {} has the wrong format. It should be either of type `str` or of type `list`".format(
-                    texts[0]
+                "`documents[0]`: {} has the wrong format. It should be either of type `str` or of type `list`".format(
+                    args[0]
                 )
             )

         with self.device_placement():
-            inputs = self._parse_and_tokenize(*texts, pad_to_max_length=pad_to_max_length)
+            inputs = self._parse_and_tokenize(*args, pad_to_max_length=pad_to_max_length)

             if self.framework == "pt":
                 inputs = self.ensure_tensor_on_device(**inputs)
...
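The translation branch above picks padding per input type: a list triggers batch encoding (and requires a pad_token_id), while a bare string does not. A usage sketch under those assumptions (the task name resolves to the default T5 checkpoint):

    from transformers import pipeline

    translator = pipeline("translation_en_to_fr")

    # Single string: tokenized without padding
    print(translator("How old are you?"))

    # List input: padded as a batch; the tokenizer must define pad_token_id
    print(translator(["How old are you?", "What is your name?"]))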
@@ -2,7 +2,7 @@ import unittest
 from typing import Iterable, List, Optional

 from transformers import pipeline
-from transformers.pipelines import Pipeline
+from transformers.pipelines import DefaultArgumentHandler, Pipeline

 from .utils import require_tf, require_torch, slow

@@ -86,6 +86,78 @@ TRANSLATION_FINETUNED_MODELS = {
 TF_TRANSLATION_FINETUNED_MODELS = {("patrickvonplaten/t5-tiny-random", "t5-small", "translation_en_to_fr")}


+class DefaultArgumentHandlerTestCase(unittest.TestCase):
+    def setUp(self) -> None:
+        self.handler = DefaultArgumentHandler()
+
+    def test_kwargs_x(self):
+        mono_data = {"X": "This is a sample input"}
+        mono_args = self.handler(**mono_data)
+
+        self.assertTrue(isinstance(mono_args, list))
+        self.assertEqual(len(mono_args), 1)
+
+        multi_data = {"x": ["This is a sample input", "This is a second sample input"]}
+        multi_args = self.handler(**multi_data)
+
+        self.assertTrue(isinstance(multi_args, list))
+        self.assertEqual(len(multi_args), 2)
+
+    def test_kwargs_data(self):
+        mono_data = {"data": "This is a sample input"}
+        mono_args = self.handler(**mono_data)
+
+        self.assertTrue(isinstance(mono_args, list))
+        self.assertEqual(len(mono_args), 1)
+
+        multi_data = {"data": ["This is a sample input", "This is a second sample input"]}
+        multi_args = self.handler(**multi_data)
+
+        self.assertTrue(isinstance(multi_args, list))
+        self.assertEqual(len(multi_args), 2)
+
+    def test_multi_kwargs(self):
+        mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"}
+        mono_args = self.handler(**mono_data)
+
+        self.assertTrue(isinstance(mono_args, list))
+        self.assertEqual(len(mono_args), 2)
+
+        multi_data = {
+            "data": ["This is a sample input", "This is a second sample input"],
+            "test": ["This is a sample input 2", "This is a second sample input 2"],
+        }
+        multi_args = self.handler(**multi_data)
+
+        self.assertTrue(isinstance(multi_args, list))
+        self.assertEqual(len(multi_args), 4)
+
+    def test_args(self):
+        mono_data = "This is a sample input"
+        mono_args = self.handler(mono_data)
+
+        self.assertTrue(isinstance(mono_args, list))
+        self.assertEqual(len(mono_args), 1)
+
+        mono_data = ["This is a sample input"]
+        mono_args = self.handler(mono_data)
+
+        self.assertTrue(isinstance(mono_args, list))
+        self.assertEqual(len(mono_args), 1)
+
+        multi_data = ["This is a sample input", "This is a second sample input"]
+        multi_args = self.handler(multi_data)
+
+        self.assertTrue(isinstance(multi_args, list))
+        self.assertEqual(len(multi_args), 2)
+
+        multi_data = ["This is a sample input", "This is a second sample input"]
+        multi_args = self.handler(*multi_data)
+
+        self.assertTrue(isinstance(multi_args, list))
+        self.assertEqual(len(multi_args), 2)
+
+
 class MonoColumnInputTestCase(unittest.TestCase):
     def _test_mono_column_pipeline(
         self,
...