Unverified Commit 38a555a8 authored by Sam Shleifer, committed by GitHub

Add Summarization to Pipelines (#3128)

* passing

* Undo stupid chg

* docs

* undo rename

* delete-cruft

* only import if you have torch

* Don't rely on dict ordering

* Fix dict ordering upstream

* docstring link

* docstring link

* remove trailing comma for 3.5 compat

* new name

* delegate kwarging

* Update kwargs
parent 2b60a26b
@@ -61,3 +61,8 @@ QuestionAnsweringPipeline
.. autoclass:: transformers.QuestionAnsweringPipeline

SummarizationPipeline
==========================================

.. autoclass:: transformers.SummarizationPipeline
@@ -113,6 +113,7 @@ from .pipelines import (
Pipeline,
PipelineDataFormat,
QuestionAnsweringPipeline,
SummarizationPipeline,
TextClassificationPipeline,
TokenClassificationPipeline,
pipeline,
...
@@ -60,6 +60,7 @@ if is_torch_available():
AutoModelForTokenClassification,
AutoModelWithLMHead,
)
from .modeling_bart import BartForConditionalGeneration
logger = logging.getLogger(__name__)
@@ -1104,6 +1105,107 @@ class QuestionAnsweringPipeline(Pipeline):
return {"answer": " ".join(words), "start": max(0, char_start_idx), "end": min(len(text), char_end_idx)}
class SummarizationPipeline(Pipeline):
"""
Summarize news articles and other documents
Usage::
summarizer = pipeline("summarization")
summarizer("Sam Shleifer writes the best docstring examples in the whole world.")
Supported Models:
The models that this pipeline can use are models that have been fine-tuned on a summarization task;
currently this is only ``BartForConditionalGeneration.from_pretrained('bart-large-cnn')``.
Arguments:
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
checkpoint identifier or an actual pre-trained model inheriting from
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
TensorFlow.
If :obj:`None`, the default of the pipeline will be loaded.
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
:class:`~transformers.PreTrainedTokenizer`.
If :obj:`None`, the default of the pipeline will be loaded.
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
Model card attributed to the model for this pipeline.
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
installed.
If no framework is specified, will default to the one currently installed. If no framework is specified
and both frameworks are installed, will default to PyTorch.
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
Reference to the object in charge of parsing supplied pipeline parameters.
device (:obj:`int`, `optional`, defaults to :obj:`-1`):
Device ordinal for CPU/GPU support. Setting this to -1 will leverage CPU, >=0 will run the model
on the associated CUDA device id.
"""
task = "summarization"
def __call__(
self,
*documents,
return_tensors=False,
return_text=True,
max_length=142,
min_length=21,
clean_up_tokenization_spaces=False,
**generate_kwargs
):
r"""
Args:
*documents: (list of strings) articles to be summarized
return_text: (bool, default=True) whether to add a decoded "summary_text" to each result
return_tensors: (bool, default=False) whether to add the raw "summary_token_ids" to each result
max_length: (`optional`) int
The max length of the sequence to be generated. Does not include tokens in input_ids.
min_length: (`optional`) int
The min length of the sequence to be generated.
no_repeat_ngram_size: (`optional`) int. Ban n-grams of this length from being repeated in the generated text.
clean_up_tokenization_spaces: (`optional`) bool whether to clean up the potential extra spaces in the decoded text
**generate_kwargs: extra kwargs passed to `self.model.generate`_
Returns:
list of dicts with 'summary_text' and/or 'summary_token_ids' for each document_to_summarize
.. _`self.model.generate`:
https://huggingface.co/transformers/model_doc/bart.html#transformers.BartForConditionalGeneration.generate
"""
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
if self.framework == "tf":
raise NotImplementedError("Tensorflow not supported")
with self.device_placement():
inputs = self._parse_and_tokenize(*documents)
inputs = self.ensure_tensor_on_device(**inputs)
summaries = self.model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_length=max_length,
min_length=min_length,
do_sample=False,
**generate_kwargs
)
results = []
for summary in summaries:
record = {}
if return_tensors:
record["summary_token_ids"] = summary
if return_text:
record["summary_text"] = self.tokenizer.decode(
summary, skip_special_tokens=True, clean_up_tokenization_spaces=clean_up_tokenization_spaces
)
results.append(record)
return results
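For orientation, here is a minimal end-to-end sketch of the pipeline added above (the article text, the length bounds, and the `num_beams` beam-search kwarg are illustrative choices, not taken from the diff)::

    from transformers import pipeline

    # Loads the task default registered below: bart-large-cnn, PyTorch only.
    summarizer = pipeline("summarization")

    article = "Sam Shleifer writes the best docstring examples in the whole world."
    # Unrecognized keyword arguments (e.g. num_beams) are delegated to self.model.generate.
    results = summarizer(article, min_length=10, max_length=60, num_beams=4)
    print(results[0]["summary_text"])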
# Register all the supported tasks here
SUPPORTED_TASKS = {
"feature-extraction": {
@@ -1162,6 +1264,16 @@ SUPPORTED_TASKS = {
"tokenizer": ("distilroberta-base", {"use_fast": False}),
},
},
"summarization": {
"impl": SummarizationPipeline,
"pt": BartForConditionalGeneration if is_torch_available() else None,
"tf": None,
"default": {
"model": {"pt": "bart-large-cnn", "tf": None},
"config": None,
"tokenizer": ("bart-large-cnn", {"use_fast": False}),
},
},
}
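As a rough sketch of how the `pipeline()` factory consumes this registry entry (simplified; the local variable names here are illustrative, not the factory's actual code)::

    targeted_task = SUPPORTED_TASKS["summarization"]

    framework = "pt"  # summarization ships with no TensorFlow implementation in this commit
    model_class = targeted_task[framework]                        # BartForConditionalGeneration
    default_model = targeted_task["default"]["model"][framework]  # "bart-large-cnn"
    tokenizer_name, tokenizer_kwargs = targeted_task["default"]["tokenizer"]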
@@ -1253,7 +1365,7 @@ def pipeline(
# Use default model/config/tokenizer for the task if no model is provided
if model is None:
- models, config, tokenizer = tuple(targeted_task["default"].values())
+ models, config, tokenizer = [targeted_task["default"][k] for k in ["model", "config", "tokenizer"]]
model = models[framework]
# Try to infer tokenizer from model or config name (if provided as str)
...
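The one-line change above swaps `tuple(targeted_task["default"].values())` for an explicit keyed lookup, per the dict-ordering commits in the message. A toy illustration of the pitfall, using hypothetical dicts rather than the library's data::

    # tuple(d.values()) silently depends on the dict literal's insertion order:
    default_a = {"model": "m", "config": None, "tokenizer": "t"}
    default_b = {"config": None, "model": "m", "tokenizer": "t"}  # same content, different order
    assert tuple(default_a.values()) != tuple(default_b.values())

    # A keyed lookup is order-independent:
    keys = ["model", "config", "tokenizer"]
    assert [default_a[k] for k in keys] == [default_b[k] for k in keys]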
@@ -247,6 +247,16 @@ class MonoColumnInputTestCase(unittest.TestCase):
expected_check_keys=["sequence"],
)
@require_torch
def test_summarization(self):
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
invalid_inputs = [4, "<mask>"]
mandatory_keys = ["summary_text"]
nlp = pipeline(task="summarization")
self._test_mono_column_pipeline(
nlp, valid_inputs, invalid_inputs, mandatory_keys,
)
class MultiColumnInputTestCase(unittest.TestCase):
def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
...
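A condensed version of what the new test exercises, runnable as a quick smoke check (requires torch; the inputs mirror `valid_inputs` above)::

    from transformers import pipeline

    nlp = pipeline(task="summarization")

    # Both a single string and a list of strings are accepted as input.
    for doc in ["A string like this", ["list of strings entry 1", "list of strings v2"]]:
        outputs = nlp(doc)
        assert all("summary_text" in record for record in outputs)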