Unverified Commit c2d0ffec authored by Nicolas Patry, committed by GitHub

Adding a new `return_full_text` parameter to TextGenerationPipeline. (#9852)

* Adding a new `return_full_text` parameter to TextGenerationPipeline.

For text-generation, the input text is sometimes used only as prompting text.
In that context, prefixing `generated_text` with the actual input
forces the caller to take an extra step to remove it.

The proposed change adds a new parameter, `return_full_text` (defaulting to
`True` for backward compatibility), that enables the caller to prevent the
prefix from being added.

* Doc quality.
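
For illustration, a minimal usage sketch of the new flag (the `gpt2` checkpoint here is just an example, not part of this PR; any causal LM works):

```python
from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")

# Default behavior (return_full_text=True): the prompt is echoed back at the
# start of `generated_text`, so callers previously had to strip it themselves.
print(generator("Hello, I am")[0]["generated_text"])  # "Hello, I am ..."

# Per-call override: only the newly generated continuation is returned.
print(generator("Hello, I am", return_full_text=False)[0]["generated_text"])

# The default can also be set once when building the pipeline.
generator = pipeline("text-generation", model="gpt2", return_full_text=False)
```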
parent bc109ae5
@@ -44,10 +44,11 @@ class TextGenerationPipeline(Pipeline):
         "TFCTRLLMHeadModel",
     ]
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, return_full_text=True, **kwargs):
         super().__init__(*args, **kwargs)
 
         self.check_model_type(self.ALLOWED_MODELS)
+        self.return_full_text = return_full_text
 
     # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
     def _parse_and_tokenize(self, *args, **kwargs):
@@ -65,6 +66,7 @@ class TextGenerationPipeline(Pipeline):
         text_inputs,
         return_tensors=False,
         return_text=True,
+        return_full_text=None,
         clean_up_tokenization_spaces=False,
         prefix=None,
         **generate_kwargs
@@ -79,6 +81,9 @@ class TextGenerationPipeline(Pipeline):
                 Whether or not to include the tensors of predictions (as token indices) in the outputs.
             return_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
                 Whether or not to include the decoded texts in the outputs.
+            return_full_text (:obj:`bool`, `optional`, defaults to :obj:`True`):
+                If set to :obj:`False`, only the added text is returned; otherwise the full text is returned. Only
+                meaningful if `return_text` is set to :obj:`True`.
             clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to clean up the potential extra spaces in the text output.
             prefix (:obj:`str`, `optional`):
@@ -94,6 +99,8 @@ class TextGenerationPipeline(Pipeline):
             - **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
               -- The token ids of the generated text.
         """
+        prefix = prefix if prefix is not None else self.model.config.prefix
+        return_full_text = return_full_text if return_full_text is not None else self.return_full_text
         if isinstance(text_inputs, str):
             text_inputs = [text_inputs]
@@ -101,7 +108,6 @@ class TextGenerationPipeline(Pipeline):
         for prompt_text in text_inputs:
             # Manage correct placement of the tensors
             with self.device_placement():
-                prefix = prefix if prefix is not None else self.model.config.prefix
                 if prefix is None and self.model.__class__.__name__ in [
                     "XLNetLMHeadModel",
                     "TransfoXLLMHeadModel",
@@ -168,7 +174,12 @@ class TextGenerationPipeline(Pipeline):
                     )
                 )
 
-                record["generated_text"] = prompt_text + text[prompt_length:]
+                if return_full_text:
+                    all_text = prompt_text + text[prompt_length:]
+                else:
+                    all_text = text[prompt_length:]
+
+                record["generated_text"] = all_text
                 result.append(record)
 
         results += [result]
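To make the new output logic easier to follow, here is the same branch as a standalone sketch (a hypothetical helper, not code from this PR): `text` is the decoded model output, which starts with the re-decoded prompt, and `prompt_length` is that prompt's character length.

```python
def build_generated_text(prompt_text: str, text: str, prompt_length: int, return_full_text: bool) -> str:
    # Hypothetical helper mirroring the diff above: slicing at `prompt_length`
    # drops the re-decoded prompt, keeping only the newly generated text.
    new_text = text[prompt_length:]
    return prompt_text + new_text if return_full_text else new_text
```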
@@ -15,6 +15,7 @@
 import unittest
 
 from transformers import pipeline
+from transformers.testing_utils import require_torch
 
 from .test_pipelines_common import MonoInputPipelineCommonMixin
@@ -41,3 +42,21 @@ class TextGenerationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
         self.assertEqual(type(outputs[0][0]["generated_text"]), str)
         self.assertEqual(list(outputs[1][0].keys()), ["generated_text"])
         self.assertEqual(type(outputs[1][0]["generated_text"]), str)
+
+    @require_torch
+    def test_generation_output_style(self):
+        text_generator = pipeline(task="text-generation", model=self.small_models[0])
+
+        # text-generation is non-deterministic by nature, we can't fully test the output
+        outputs = text_generator("This is a test")
+        self.assertIn("This is a test", outputs[0]["generated_text"])
+
+        outputs = text_generator("This is a test", return_full_text=False)
+        self.assertNotIn("This is a test", outputs[0]["generated_text"])
+
+        text_generator = pipeline(task="text-generation", model=self.small_models[0], return_full_text=False)
+        outputs = text_generator("This is a test")
+        self.assertNotIn("This is a test", outputs[0]["generated_text"])
+
+        outputs = text_generator("This is a test", return_full_text=True)
+        self.assertIn("This is a test", outputs[0]["generated_text"])