Unverified Commit bf905644 authored by Nicolas Patry, committed by GitHub

Removing duplicated code for Translation, Summarization and Text2TextGeneration pipelines (#9433)

* Merging all duplicated code for Text2TextPipeline while preserving backward compat.

* Fixing TranslationPipeline Hierarchy + return_name

* torch import guard.

* Update isort version.

* Remove code from other PR (disentanglement).

* Renamed example to something more agnostic.
parent f33a6f34
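
The user-visible contract is unchanged by the merge: each task pipeline keeps its historical output key, now derived from the shared return_name attribute. A minimal usage sketch of the preserved behavior (checkpoints download on first use; the output dicts show keys only, values are illustrative):

    from transformers import pipeline

    # Every task still returns its task-specific key after the refactor.
    summarizer = pipeline("summarization")
    summarizer("A long article ...", min_length=5, max_length=20)  # [{'summary_text': ...}]

    translator = pipeline("translation_en_to_fr")
    translator("How old are you?")  # [{'translation_text': ...}]

    generator = pipeline("text2text-generation")
    generator("question: What is 42 ?")  # [{'generated_text': ...}]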
@@ -6,7 +6,7 @@
from .base import PIPELINE_INIT_ARGS, Pipeline

if is_tf_available():
    import tensorflow as tf

    from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING

if is_torch_available():
    from ..models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
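
The "torch import guard" bullet refers to the pattern above: framework imports live behind availability checks so the module imports cleanly with only one of the two frameworks installed. A minimal sketch of the same pattern in user code (assuming the transformers.file_utils helpers of this era):

    from transformers.file_utils import is_tf_available, is_torch_available

    if is_torch_available():
        import torch  # imported only when PyTorch is installed

    if is_tf_available():
        import tensorflow as tf  # likewise for TensorFlow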
@@ -15,49 +15,53 @@
logger = logging.get_logger(__name__)


@add_end_docstrings(PIPELINE_INIT_ARGS)
class Text2TextGenerationPipeline(Pipeline):
    """
    Pipeline for text to text generation using seq2seq models.

    This Text2TextGenerationPipeline pipeline can currently be loaded from :func:`~transformers.pipeline` using the
    following task identifier: :obj:`"text2text-generation"`.

    The models that this pipeline can use are models that have been fine-tuned on a translation task. See the
    up-to-date list of available models on `huggingface.co/models <https://huggingface.co/models?filter=seq2seq>`__.

    Usage::

        text2text_generator = pipeline("text2text-generation")
        text2text_generator("question: What is 42 ? context: 42 is the answer to life, the universe and everything")
    """

    # Used in the return key of the pipeline.
    return_name = "generated"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.check_model_type(
            TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
            if self.framework == "tf"
            else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
        )

    def check_inputs(self, input_length: int, min_length: int, max_length: int):
        """
        Checks whether there might be something wrong with the given input with regard to the model.
        """
        return True

    def __call__(
        self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
    ):
        r"""
        Generate the output text(s) using text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                Input text for the encoder.
            return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to include the tensors of predictions (as token indices) in the outputs.
            return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to include the decoded texts in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
@@ -67,36 +71,32 @@
        Return:
            A list or a list of list of :obj:`dict`: Each result comes as a dictionary with the following keys:

            - **generated_text** (:obj:`str`, present when ``return_text=True``) -- The generated text.
            - **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
              -- The token ids of the generated text.
        """
        assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"

        prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
        if isinstance(args[0], list):
            assert (
                self.tokenizer.pad_token_id is not None
            ), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
            args = ([prefix + arg for arg in args[0]],)
            padding = True
        elif isinstance(args[0], str):
            args = (prefix + args[0],)
            padding = False
        else:
            raise ValueError(
                "`args[0]`: {} has the wrong format. It should be of type `str` or type `list`".format(args[0])
            )

        with self.device_placement():
            inputs = self._parse_and_tokenize(*args, padding=padding, **generate_kwargs)

            if self.framework == "pt":
                inputs = self.ensure_tensor_on_device(**inputs)
@@ -105,35 +105,25 @@
                input_length = tf.shape(inputs["input_ids"])[-1].numpy()

            min_length = generate_kwargs.get("min_length", self.model.config.min_length)
            max_length = generate_kwargs.get("max_length", self.model.config.max_length)
            self.check_inputs(input_length, min_length, max_length)

            # truncation should be used by _parse_and_tokenize
            generate_kwargs.pop("truncation", None)
            generations = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **generate_kwargs,
            )
            results = []
            for generation in generations:
                record = {}
                if return_tensors:
                    record[f"{self.return_name}_token_ids"] = generation
                if return_text:
                    record[f"{self.return_name}_text"] = self.tokenizer.decode(
                        generation,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )
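
The base class is effectively a template: return_name parameterizes the output keys and check_inputs is a no-op hook for subclasses. A hedged sketch of how a third subclass could plug into the same machinery if added to this module (ParaphrasePipeline and its key names are invented for illustration, not part of this PR):

    class ParaphrasePipeline(Text2TextGenerationPipeline):
        # Hypothetical subclass: results would come back as
        # {'paraphrase_text': ...} or {'paraphrase_token_ids': ...}.
        return_name = "paraphrase"

        def check_inputs(self, input_length: int, min_length: int, max_length: int):
            # Optional hook: warn about inputs the model is unlikely to handle well.
            if input_length > max_length:
                logger.warning(
                    "Input length {} exceeds max_length {}; the output may be truncated.".format(
                        input_length, max_length
                    )
                )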
@@ -142,42 +132,42 @@

@add_end_docstrings(PIPELINE_INIT_ARGS)
class SummarizationPipeline(Text2TextGenerationPipeline):
    """
    Summarize news articles and other documents.

    This summarizing pipeline can currently be loaded from :func:`~transformers.pipeline` using the following task
    identifier: :obj:`"summarization"`.

    The models that this pipeline can use are models that have been fine-tuned on a summarization task, which is
    currently, '`bart-large-cnn`', '`t5-small`', '`t5-base`', '`t5-large`', '`t5-3b`', '`t5-11b`'. See the up-to-date
    list of available models on `huggingface.co/models <https://huggingface.co/models?filter=summarization>`__.

    Usage::

        # use bart in pytorch
        summarizer = pipeline("summarization")
        summarizer("An apple a day, keeps the doctor away", min_length=5, max_length=20)

        # use t5 in tf
        summarizer = pipeline("summarization", model="t5-base", tokenizer="t5-base", framework="tf")
        summarizer("An apple a day, keeps the doctor away", min_length=5, max_length=20)
    """

    # Used in the return key of the pipeline.
    return_name = "summary"

    def __call__(self, *args, **kwargs):
        r"""
        Summarize the text(s) given as inputs.

        Args:
            documents (:obj:`str` or :obj:`List[str]`):
                One or several articles (or one list of articles) to summarize.
            return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to include the tensors of predictions (as token indices) in the outputs.
            return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to include the decoded texts in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
@@ -187,104 +177,67 @@
        Return:
            A list or a list of list of :obj:`dict`: Each result comes as a dictionary with the following keys:

            - **summary_text** (:obj:`str`, present when ``return_text=True``) -- The summary of the corresponding
              input.
            - **summary_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
              -- The token ids of the summary.
        """
        return super().__call__(*args, **kwargs)

    def check_inputs(self, input_length: int, min_length: int, max_length: int) -> bool:
        """
        Checks whether there might be something wrong with the given input with regard to the model.
        """
        if input_length < min_length // 2:
            logger.warning(
                "Your min_length is set to {}, but your input_length is only {}. You might consider decreasing "
                "min_length manually, e.g. summarizer('...', min_length=10)".format(min_length, input_length)
            )

        if input_length < max_length:
            logger.warning(
                "Your max_length is set to {}, but your input_length is only {}. You might consider decreasing "
                "max_length manually, e.g. summarizer('...', max_length=50)".format(max_length, input_length)
            )
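
Because check_inputs now runs just before generation, the old inline warnings fire exactly as they used to. A quick sketch of when the summarization warning triggers (token counts are approximate, so the exact thresholds may differ):

    from transformers import pipeline

    summarizer = pipeline("summarization")

    # An input shorter than max_length logs the "consider decreasing max_length" warning:
    summarizer("A short text.", max_length=50)

    # Keeping max_length at or below the tokenized input length avoids it:
    summarizer("A short text.", min_length=2, max_length=5)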
@add_end_docstrings(PIPELINE_INIT_ARGS)
class TranslationPipeline(Text2TextGenerationPipeline):
    """
    Translates from one language to another.

    This translation pipeline can currently be loaded from :func:`~transformers.pipeline` using the following task
    identifier: :obj:`"translation_xx_to_yy"`.

    The models that this pipeline can use are models that have been fine-tuned on a translation task. See the
    up-to-date list of available models on `huggingface.co/models
    <https://huggingface.co/models?filter=translation>`__.

    Usage::

        en_fr_translator = pipeline("translation_en_to_fr")
        en_fr_translator("How old are you?")
    """

    # Used in the return key of the pipeline.
    return_name = "translation"

    def check_inputs(self, input_length: int, min_length: int, max_length: int):
        if input_length > 0.9 * max_length:
            logger.warning(
                "Your input_length: {} is bigger than 0.9 * max_length: {}. You might consider increasing your "
                "max_length manually, e.g. translator('...', max_length=400)".format(input_length, max_length)
            )

    def __call__(self, *args, **kwargs):
        r"""
        Translate the text(s) given as inputs.

        Args:
            args (:obj:`str` or :obj:`List[str]`):
                Texts to be translated.
            return_tensors (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to include the tensors of predictions (as token indices) in the outputs.
            return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to include the decoded texts in the outputs.
@@ -298,48 +251,8 @@
        Return:
            A list or a list of list of :obj:`dict`: Each result comes as a dictionary with the following keys:

            - **translation_text** (:obj:`str`, present when ``return_text=True``) -- The translation.
            - **translation_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
              -- The token ids of the translation.
        """
        return super().__call__(*args, **kwargs)
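
TranslationPipeline's override keeps only the length-ratio warning that used to live inline in its __call__. A sketch of the behavior (the repeated sentence just manufactures a long input; lengths are illustrative):

    from transformers import pipeline

    translator = pipeline("translation_en_to_fr")
    very_long_text = " ".join(["How old are you?"] * 100)

    # input_length > 0.9 * max_length logs a warning suggesting a larger max_length:
    translator(very_long_text)

    # Raising max_length (forwarded to model.generate via generate_kwargs) avoids it:
    translator(very_long_text, max_length=400)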