Unverified Commit 995a958d authored by Boris Dayma's avatar Boris Dayma Committed by GitHub
Browse files

feat: allow prefix for any generative model (#5885)



* feat: allow padding_text for any generative model

* docs(pipelines.py): correct typo

* Update src/transformers/pipelines.py
Co-authored-by: default avatarSam Shleifer <sshleifer@gmail.com>

* feat: rename padding_text to prefix

* fix: cannot tokenize empty text

* fix: pass prefix arg to pipeline

* test: add prefix to text-generation pipeline

* style: fix style

* style: clean code and variable name more explicit

* set arg docstring to optional
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarSam Shleifer <sshleifer@gmail.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent ce37be9d
...@@ -61,7 +61,7 @@ MODEL_CLASSES = { ...@@ -61,7 +61,7 @@ MODEL_CLASSES = {
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology # in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered. (except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia, remainder of the story. 1883 Western Siberia,
...@@ -122,12 +122,14 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text): ...@@ -122,12 +122,14 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
def prepare_xlnet_input(args, _, tokenizer, prompt_text):
    """Prepend a priming prefix to the prompt text for XLNet.

    Precedence for the prefix: the explicit ``--prefix`` argument, then the
    deprecated ``--padding_text`` argument, then the module-level ``PREFIX``
    default (the long XLNet priming paragraph).
    """
    # `or` picks the first non-empty value, matching the original chained conditionals.
    prefix = args.prefix or args.padding_text or PREFIX
    return prefix + prompt_text
def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
    """Prepend a priming prefix to the prompt text for Transformer-XL.

    Precedence for the prefix: the explicit ``--prefix`` argument, then the
    deprecated ``--padding_text`` argument, then the module-level ``PREFIX``
    default (the long priming paragraph).
    """
    # `or` picks the first non-empty value, matching the original chained conditionals.
    prefix = args.prefix or args.padding_text or PREFIX
    return prefix + prompt_text
...@@ -182,7 +184,8 @@ def main(): ...@@ -182,7 +184,8 @@ def main():
parser.add_argument("--k", type=int, default=0) parser.add_argument("--k", type=int, default=0)
parser.add_argument("--p", type=float, default=0.9) parser.add_argument("--p", type=float, default=0.9)
parser.add_argument("--padding_text", type=str, default="", help="Padding text for Transfo-XL and XLNet.") parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")
parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.") parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
...@@ -241,7 +244,8 @@ def main(): ...@@ -241,7 +244,8 @@ def main():
preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs
) )
else: else:
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt") prefix = args.prefix if args.prefix else args.padding_text
encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt")
encoded_prompt = encoded_prompt.to(args.device) encoded_prompt = encoded_prompt.to(args.device)
if encoded_prompt.size()[-1] == 0: if encoded_prompt.size()[-1] == 0:
......
...@@ -752,11 +752,11 @@ class TextGenerationPipeline(Pipeline): ...@@ -752,11 +752,11 @@ class TextGenerationPipeline(Pipeline):
`huggingface.co/models <https://huggingface.co/models?search=&filter=lm-head>`__. `huggingface.co/models <https://huggingface.co/models?search=&filter=lm-head>`__.
""" """
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia # Prefix text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology # in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family XL_PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered. (except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia, remainder of the story. 1883 Western Siberia,
...@@ -765,7 +765,7 @@ class TextGenerationPipeline(Pipeline): ...@@ -765,7 +765,7 @@ class TextGenerationPipeline(Pipeline):
father initially slaps him for making such an accusation, Rasputin watches as the father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. """ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
ALLOWED_MODELS = [ ALLOWED_MODELS = [
"XLNetLMHeadModel", "XLNetLMHeadModel",
...@@ -809,7 +809,13 @@ class TextGenerationPipeline(Pipeline): ...@@ -809,7 +809,13 @@ class TextGenerationPipeline(Pipeline):
return inputs return inputs
def __call__( def __call__(
self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs self,
*args,
return_tensors=False,
return_text=True,
clean_up_tokenization_spaces=False,
prefix=None,
**generate_kwargs
): ):
""" """
Complete the prompt(s) given as inputs. Complete the prompt(s) given as inputs.
...@@ -823,6 +829,8 @@ class TextGenerationPipeline(Pipeline): ...@@ -823,6 +829,8 @@ class TextGenerationPipeline(Pipeline):
Whether or not to include the decoded texts in the outputs. Whether or not to include the decoded texts in the outputs.
clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`): clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to clean up the potential extra spaces in the text output. Whether or not to clean up the potential extra spaces in the text output.
prefix (:obj:`str`, `optional`):
Prefix added to prompt.
generate_kwargs: generate_kwargs:
Additional keyword arguments to pass along to the generate method of the model (see the generate Additional keyword arguments to pass along to the generate method of the model (see the generate
method corresponding to your framework `here <./model.html#generative-models>`__). method corresponding to your framework `here <./model.html#generative-models>`__).
...@@ -841,27 +849,27 @@ class TextGenerationPipeline(Pipeline): ...@@ -841,27 +849,27 @@ class TextGenerationPipeline(Pipeline):
for prompt_text in text_inputs: for prompt_text in text_inputs:
# Manage correct placement of the tensors # Manage correct placement of the tensors
with self.device_placement(): with self.device_placement():
if self.model.__class__.__name__ in [ prefix = prefix if prefix is not None else self.model.config.prefix
if prefix is None and self.model.__class__.__name__ in [
"XLNetLMHeadModel", "XLNetLMHeadModel",
"TransfoXLLMHeadModel", "TransfoXLLMHeadModel",
"TFXLNetLMHeadModel", "TFXLNetLMHeadModel",
"TFTransfoXLLMHeadModel", "TFTransfoXLLMHeadModel",
]: ]:
# For XLNet and TransformerXL we had an article to the prompt to give more state to the model. # For XLNet and TransformerXL we add an article to the prompt to give more state to the model.
padding_text = self.PADDING_TEXT + self.tokenizer.eos_token prefix = self.XL_PREFIX
padding = self._parse_and_tokenize(padding_text, padding=False, add_special_tokens=False)
if prefix:
prefix_inputs = self._parse_and_tokenize(prefix, padding=False, add_special_tokens=False)
# This impacts max_length and min_length argument that need adjusting. # This impacts max_length and min_length argument that need adjusting.
padding_length = padding["input_ids"].shape[-1] prefix_length = prefix_inputs["input_ids"].shape[-1]
if "max_length" in generate_kwargs and generate_kwargs["max_length"] is not None: if generate_kwargs.get("max_length", None) is not None:
generate_kwargs["max_length"] += padding_length generate_kwargs["max_length"] += prefix_length
if "min_length" in generate_kwargs and generate_kwargs["min_length"] is not None: if generate_kwargs.get("min_length", None) is not None:
generate_kwargs["min_length"] += padding_length generate_kwargs["min_length"] += prefix_length
inputs = self._parse_and_tokenize( prefix = prefix or ""
padding_text + prompt_text, padding=False, add_special_tokens=False inputs = self._parse_and_tokenize(prefix + prompt_text, padding=False, add_special_tokens=False)
)
else:
inputs = self._parse_and_tokenize(prompt_text, padding=False, add_special_tokens=False)
# set input_ids to None to allow empty prompt # set input_ids to None to allow empty prompt
if inputs["input_ids"].shape[-1] == 0: if inputs["input_ids"].shape[-1] == 0:
......
...@@ -424,12 +424,14 @@ class MonoColumnInputTestCase(unittest.TestCase): ...@@ -424,12 +424,14 @@ class MonoColumnInputTestCase(unittest.TestCase):
for model_name in TEXT_GENERATION_FINETUNED_MODELS: for model_name in TEXT_GENERATION_FINETUNED_MODELS:
nlp = pipeline(task="text-generation", model=model_name, tokenizer=model_name, framework="pt") nlp = pipeline(task="text-generation", model=model_name, tokenizer=model_name, framework="pt")
self._test_mono_column_pipeline(nlp, VALID_INPUTS, {}) self._test_mono_column_pipeline(nlp, VALID_INPUTS, {})
self._test_mono_column_pipeline(nlp, VALID_INPUTS, {}, prefix="This is ")
@require_tf
def test_tf_text_generation(self):
    # Smoke-test the TF text-generation pipeline for each finetuned model:
    # once with default settings and once with an explicit `prefix` string
    # prepended to every prompt (the feature added by this commit).
    # NOTE(review): instantiates real models — assumes TEXT_GENERATION_FINETUNED_MODELS
    # contains small test checkpoints; confirm against the test-constants module.
    for model_name in TEXT_GENERATION_FINETUNED_MODELS:
        nlp = pipeline(task="text-generation", model=model_name, tokenizer=model_name, framework="tf")
        self._test_mono_column_pipeline(nlp, VALID_INPUTS, {})
        self._test_mono_column_pipeline(nlp, VALID_INPUTS, {}, prefix="This is ")
@slow @slow
@require_torch @require_torch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment