Unverified Commit 022e8fab authored by Patrick von Platen, committed by GitHub

Adds translation pipeline (#3419)

* fix merge conflicts

* add t5 summarization example

* change parameters for t5 summarization

* make style

* add first code snippet for translation

* only add prefixes

* add prefix patterns

* make style

* renaming

* fix conflicts

* remove unused patterns

* solve conflicts

* fix merge conflicts

* remove translation example

* remove summarization example

* make sure tensors are in numpy for float comparison

* re-add t5 config

* fix t5 import config typo

* make style

* remove unused numpy statements

* update docstring

* import translation pipeline
parent 3c5c5675
@@ -116,6 +116,7 @@ from .pipelines import (
SummarizationPipeline,
TextClassificationPipeline,
TokenClassificationPipeline,
TranslationPipeline,
pipeline,
)
from .tokenization_albert import AlbertTokenizer
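With the hunk above, `TranslationPipeline` is exported from the package root alongside the other pipelines. A minimal usage sketch (assuming this commit is installed and the default checkpoint can be downloaded):

```python
from transformers import TranslationPipeline, pipeline  # both are now exported

en_fr_translator = pipeline("translation_en_to_fr")  # resolves to the t5-base default
outputs = en_fr_translator("How old are you?")
# outputs: a list of dicts, each with a "translation_text" key
```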
@@ -130,7 +130,9 @@ class PipelineDataFormat:
SUPPORTED_FORMATS = ["json", "csv", "pipe"]
- def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
+ def __init__(
+     self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False,
+ ):
self.output_path = output_path
self.input_path = input_path
self.column = column.split(",") if column is not None else [""]
@@ -176,7 +178,7 @@ class PipelineDataFormat:
@staticmethod
def from_str(
- format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False
+ format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False,
):
if format == "json":
return JsonPipelineDataFormat(output_path, input_path, column, overwrite=overwrite)
@@ -189,7 +191,9 @@ class PipelineDataFormat:
class CsvPipelineDataFormat(PipelineDataFormat):
- def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
+ def __init__(
+     self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False,
+ ):
super().__init__(output_path, input_path, column, overwrite=overwrite)
def __iter__(self):
@@ -210,7 +214,9 @@ class CsvPipelineDataFormat(PipelineDataFormat):
class JsonPipelineDataFormat(PipelineDataFormat):
- def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False):
+ def __init__(
+     self, output_path: Optional[str], input_path: Optional[str], column: Optional[str], overwrite=False,
+ ):
super().__init__(output_path, input_path, column, overwrite=overwrite)
with open(input_path, "r") as f:
@@ -1120,7 +1126,11 @@ class QuestionAnsweringPipeline(Pipeline):
chars_idx += len(word) + 1
# Join text with spaces
- return {"answer": " ".join(words), "start": max(0, char_start_idx), "end": min(len(text), char_end_idx)}
+ return {
+     "answer": " ".join(words),
+     "start": max(0, char_start_idx),
+     "end": min(len(text), char_end_idx),
+ }
class SummarizationPipeline(Pipeline):
@@ -1223,18 +1233,18 @@ class SummarizationPipeline(Pipeline):
inputs = self.ensure_tensor_on_device(**inputs)
input_length = inputs["input_ids"].shape[-1]
elif self.framework == "tf":
- input_length = tf.shape(inputs["input_ids"])[-1]
+ input_length = tf.shape(inputs["input_ids"])[-1].numpy()
if input_length < self.model.config.min_length // 2:
logger.warning(
-     "Your min_length is set to {}, but you input_length is only {}. You might consider decreasing min_length in config and insert config manually".format(
+     "Your min_length is set to {}, but your input_length is only {}. You might consider decreasing min_length manually, e.g. summarizer('...', min_length=10)".format(
self.model.config.min_length, input_length
)
)
if input_length < self.model.config.max_length:
logger.warning(
-     "Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length in config and insert config manually".format(
+     "Your max_length is set to {}, but your input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format(
self.model.config.max_length, input_length
)
)
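The `.numpy()` call added above (the "make sure tensors are in numpy for float comparison" bullet in the commit message) keeps `input_length` a plain number in the TF branch, matching the PyTorch branch where `.shape[-1]` is already an `int`, so the comparisons against the integer config values behave the same in both frameworks. A small illustration, assuming TF2 eager execution:

```python
import tensorflow as tf

input_ids = tf.constant([[0, 1, 2, 3]])

as_tensor = tf.shape(input_ids)[-1]          # scalar int32 *tensor*
as_number = tf.shape(input_ids)[-1].numpy()  # plain NumPy integer

# Comparing the tensor yields a tensor; comparing the NumPy value yields
# an ordinary boolean, which is what the warning checks above rely on.
print(as_tensor < 10)  # tf.Tensor(True, ...)
print(as_number < 10)  # True
```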
@@ -1250,7 +1260,115 @@ class SummarizationPipeline(Pipeline):
record["summary_token_ids"] = summary
if return_text:
record["summary_text"] = self.tokenizer.decode(
- summary, skip_special_tokens=True, clean_up_tokenization_spaces=clean_up_tokenization_spaces
+ summary, skip_special_tokens=True, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
results.append(record)
return results
class TranslationPipeline(Pipeline):
"""
Translates from one language to another.
Usage::
en_fr_translator = pipeline("translation_en_to_fr")
en_fr_translator("How old are you?")
Supported Models: "t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"
Arguments:
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
checkpoint identifier or an actual pre-trained model inheriting from
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
TensorFlow.
If :obj:`None`, the default of the pipeline will be loaded.
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
:class:`~transformers.PreTrainedTokenizer`.
If :obj:`None`, the default of the pipeline will be loaded.
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
Model card attributed to the model for this pipeline.
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
installed.
If no framework is specified, will default to the one currently installed. If no framework is specified
and both frameworks are installed, will default to PyTorch.
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
Reference to the object in charge of parsing supplied pipeline parameters.
device (:obj:`int`, `optional`, defaults to :obj:`-1`):
Device ordinal for CPU/GPU support. Setting this to -1 will leverage CPU, >=0 will run the model
on the associated CUDA device id.
"""
def __call__(
self, *texts, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
):
r"""
Args:
*texts: (list of strings) texts to be translated
return_text: (bool, default=True) whether to add a decoded "translation_text" to each result
return_tensors: (bool, default=False) whether to add the raw "translation_token_ids" to each result
**generate_kwargs: extra kwargs passed to `self.model.generate`_
Returns:
list of dicts with 'translation_text' and/or 'translation_token_ids' for each text_to_translate
.. _`self.model.generate`:
https://huggingface.co/transformers/model_doc/bart.html#transformers.BartForConditionalGeneration.generate
"""
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
if isinstance(texts[0], list):
assert (
self.tokenizer.pad_token_id is not None
), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
texts = ([prefix + text for text in texts[0]],)
pad_to_max_length = True
elif isinstance(texts[0], str):
texts = (prefix + texts[0],)
pad_to_max_length = False
else:
raise ValueError(
    "`texts[0]`: {} has the wrong format. It should be either of type `str` or type `list`".format(
        texts[0]
    )
)
with self.device_placement():
inputs = self._parse_and_tokenize(*texts, pad_to_max_length=pad_to_max_length)
if self.framework == "pt":
inputs = self.ensure_tensor_on_device(**inputs)
input_length = inputs["input_ids"].shape[-1]
elif self.framework == "tf":
input_length = tf.shape(inputs["input_ids"])[-1].numpy()
if input_length > 0.9 * self.model.config.max_length:
logger.warning(
"Your input_length: {} is bigger than 0.9 * max_length: {}. You might consider increasing your max_length manually, e.g. translator('...', max_length=400)".format(
input_length, self.model.config.max_length
)
)
translations = self.model.generate(
inputs["input_ids"], attention_mask=inputs["attention_mask"], **generate_kwargs,
)
results = []
for translation in translations:
record = {}
if return_tensors:
record["translation_token_ids"] = translation
if return_text:
record["translation_text"] = self.tokenizer.decode(
translation,
skip_special_tokens=True,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
results.append(record)
return results
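A sketch of the calling conventions `__call__` above accepts (checkpoint downloads assumed):

```python
from transformers import pipeline

translator = pipeline("translation_en_to_de")

# Single string: the model's task prefix is prepended, no padding needed.
translator("How old are you?")

# List of strings: the batched path, which requires a tokenizer pad_token_id.
translator(["How old are you?", "What is your name?"])

# Extra kwargs are forwarded to model.generate, e.g. the max_length bump
# suggested by the warning above.
translator("A rather long document ...", max_length=400)

# return_tensors=True adds the raw "translation_token_ids" to each record.
translator("How old are you?", return_tensors=True, return_text=False)
```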
@@ -1324,6 +1442,36 @@ SUPPORTED_TASKS = {
"tokenizer": ("bart-large-cnn", {"use_fast": False}),
},
},
"translation_en_to_fr": {
"impl": TranslationPipeline,
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
"pt": AutoModelWithLMHead if is_torch_available() else None,
"default": {
"model": {"pt": "t5-base", "tf": "t5-base"},
"config": None,
"tokenizer": ("t5-base", {"use_fast": False}),
},
},
"translation_en_to_de": {
"impl": TranslationPipeline,
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
"pt": AutoModelWithLMHead if is_torch_available() else None,
"default": {
"model": {"pt": "t5-base", "tf": "t5-base"},
"config": None,
"tokenizer": ("t5-base", {"use_fast": False}),
},
},
"translation_en_to_ro": {
"impl": TranslationPipeline,
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
"pt": AutoModelWithLMHead if is_torch_available() else None,
"default": {
"model": {"pt": "t5-base", "tf": "t5-base"},
"config": None,
"tokenizer": ("t5-base", {"use_fast": False}),
},
},
}
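Each registry entry maps a task string to the pipeline class, the framework-specific auto-model, and a t5-base default; callers can override the default checkpoint, as the tests below do with t5-small. A sketch:

```python
from transformers import pipeline

# Registry default: loads t5-base for English-to-Romanian.
ro_translator = pipeline("translation_en_to_ro")

# Explicit override (the test suite below uses t5-small for speed).
small_translator = pipeline(
    task="translation_en_to_de", model="t5-small", tokenizer="t5-small"
)
```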
@@ -1472,4 +1620,4 @@ def pipeline(
)
model = model_class.from_pretrained(model, config=config, **model_kwargs)
- return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs)
+ return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs,)
@@ -81,6 +81,12 @@ TF_FILL_MASK_FINETUNED_MODELS = [
SUMMARIZATION_FINETUNED_MODELS = {("bart-large-cnn", "bart-large-cnn"), ("t5-small", "t5-small")}
TF_SUMMARIZATION_FINETUNED_MODELS = {("t5-small", "t5-small")}
TRANSLATION_FINETUNED_MODELS = {
("t5-small", "t5-small", "translation_en_to_de"),
("t5-small", "t5-small", "translation_en_to_ro"),
}
TF_TRANSLATION_FINETUNED_MODELS = {("t5-small", "t5-small", "translation_en_to_fr")}
class MonoColumnInputTestCase(unittest.TestCase):
def _test_mono_column_pipeline(
@@ -272,6 +278,28 @@ class MonoColumnInputTestCase(unittest.TestCase):
nlp, valid_inputs, invalid_inputs, mandatory_keys,
)
@require_torch
def test_translation(self):
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
invalid_inputs = [4, "<mask>"]
mandatory_keys = ["translation_text"]
for model, tokenizer, task in TRANSLATION_FINETUNED_MODELS:
nlp = pipeline(task=task, model=model, tokenizer=tokenizer)
self._test_mono_column_pipeline(
nlp, valid_inputs, invalid_inputs, mandatory_keys,
)
@require_tf
def test_tf_translation(self):
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
invalid_inputs = [4, "<mask>"]
mandatory_keys = ["translation_text"]
for model, tokenizer, task in TF_TRANSLATION_FINETUNED_MODELS:
nlp = pipeline(task=task, model=model, tokenizer=tokenizer, framework="tf")
self._test_mono_column_pipeline(
nlp, valid_inputs, invalid_inputs, mandatory_keys,
)
class MultiColumnInputTestCase(unittest.TestCase):
def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):