"examples/vscode:/vscode.git/clone" did not exist on "77f4c46b501322e9bffb5416dfbf0397deefd7d8"
Unverified Commit dc540dd3 authored by Nicolas Patry, committed by GitHub

Adding `handle_long_generation` parameter for `text-generation` pipeline. (#14118)

* Adding `handle_long_generation` parameter for `text-generation` pipeline.

* More error handling

* Fixing tests by dropping TF support for this functionality; it needs `max_new_tokens` to make it possible to understand the user's intent. Otherwise, `max_length` == `tokenizer.model_max_length` < `input_ids.shape[0]`.

* Fixing doc ?

* Doc ?

* Remove link from doc.

* Caught an issue on roberta.

* Damn doc.

* Non BC proposal ?

* Cleaning the fix ?

* Finally using only a test override.

* Don't need to modify this.

* Bad print.
parent d37f1fb8
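
For context, here is a minimal usage sketch of the option this commit introduces; the checkpoint (`"gpt2"`) and the prompt are illustrative choices, not part of the commit:

```python
from transformers import pipeline

# Any causal LM checkpoint works; "gpt2" is only an illustration (1024-token window).
generator = pipeline("text-generation", model="gpt2")

very_long_prompt = "This is a test " * 2000  # far longer than the model can attend to

# Without the new option, a prompt this long overflows the model's maximum length.
# With handle_long_generation="hole", the prompt is truncated from the left so that
# max_new_tokens tokens can still be generated.
outputs = generator(very_long_prompt, handle_long_generation="hole", max_new_tokens=20)
print(outputs[0]["generated_text"][-200:])
```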
@@ -254,7 +254,7 @@ class ReformerEmbeddings(nn.Module):
         if position_ids.shape[-1] > self.max_position_embeddings:
             raise ValueError(
-                f"Sequence Length: {position_ids.shape[-1]} has to be larger equal than "
+                f"Sequence Length: {position_ids.shape[-1]} has to be less or equal than "
                 f"config.max_position_embeddings {self.max_position_embeddings}."
             )
...
@@ -75,6 +75,7 @@ class TextGenerationPipeline(Pipeline):
         return_type=None,
         clean_up_tokenization_spaces=None,
         prefix=None,
+        handle_long_generation=None,
         **generate_kwargs
     ):
         preprocess_params = {}
@@ -85,14 +86,24 @@ class TextGenerationPipeline(Pipeline):
                 prefix, padding=False, add_special_tokens=False, return_tensors=self.framework
             )
             prefix_length = prefix_inputs["input_ids"].shape[-1]

-            if "max_length" in generate_kwargs:
+            if "max_new_tokens" in generate_kwargs:
+                pass
+            elif "max_length" in generate_kwargs:
                 generate_kwargs["max_length"] += prefix_length
             else:
                 generate_kwargs["max_length"] = self.model.config.max_length + prefix_length

             if "min_length" in generate_kwargs:
                 generate_kwargs["min_length"] += prefix_length
+        if handle_long_generation is not None:
+            if handle_long_generation not in {"hole"}:
+                raise ValueError(
+                    f"{handle_long_generation} is not a valid value for `handle_long_generation` parameter expected [None, 'hole']"
+                )
+            preprocess_params["handle_long_generation"] = handle_long_generation
+
+        preprocess_params.update(generate_kwargs)
         forward_params = generate_kwargs

         postprocess_params = {}
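
The `max_new_tokens` special case above is easier to read outside the diff. Below is a minimal standalone sketch of the same bookkeeping, with a hypothetical helper name and a plain dict standing in for the pipeline's `generate_kwargs`:

```python
# A minimal, hypothetical sketch of the length bookkeeping in _sanitize_parameters above:
# when a prefix is prepended to every prompt, max_length/min_length must grow by the
# prefix length, unless the caller already expressed intent through max_new_tokens.
def adjust_generate_kwargs(generate_kwargs, prefix_length, model_default_max_length=20):
    if "max_new_tokens" in generate_kwargs:
        # max_new_tokens only counts newly generated tokens, so the prefix does not affect it.
        pass
    elif "max_length" in generate_kwargs:
        generate_kwargs["max_length"] += prefix_length
    else:
        generate_kwargs["max_length"] = model_default_max_length + prefix_length

    if "min_length" in generate_kwargs:
        generate_kwargs["min_length"] += prefix_length
    return generate_kwargs


print(adjust_generate_kwargs({"max_length": 50}, prefix_length=4))      # {'max_length': 54}
print(adjust_generate_kwargs({"max_new_tokens": 20}, prefix_length=4))  # {'max_new_tokens': 20}
```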
@@ -136,6 +147,16 @@ class TextGenerationPipeline(Pipeline):
                 Whether or not to clean up the potential extra spaces in the text output.
             prefix (:obj:`str`, `optional`):
                 Prefix added to prompt.
+            handle_long_generation (:obj:`str`, `optional`):
+                By default, this pipeline does not handle long generation (generation that, in one form or another,
+                exceeds the model maximum length). There is no perfect way to address this (more info:
+                https://github.com/huggingface/transformers/issues/14033#issuecomment-948385227). This provides common
+                strategies to work around the problem depending on your use case.
+
+                - :obj:`None`: default strategy, where nothing in particular happens
+                - :obj:`"hole"`: Truncates the left of the input and leaves a gap wide enough to let generation
+                  happen (might truncate a lot of the prompt; not suitable when generation exceeds the model capacity)
+
             generate_kwargs:
                 Additional keyword arguments to pass along to the generate method of the model (see the generate method
                 corresponding to your framework `here <./model.html#generative-models>`__).
@@ -149,11 +170,31 @@ class TextGenerationPipeline(Pipeline):
         """
         return super().__call__(text_inputs, **kwargs)

-    def preprocess(self, prompt_text, prefix=""):
+    def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
         inputs = self.tokenizer(
             prefix + prompt_text, padding=False, add_special_tokens=False, return_tensors=self.framework
         )
         inputs["prompt_text"] = prompt_text
+
+        if handle_long_generation == "hole":
+            cur_len = inputs["input_ids"].shape[-1]
+            if "max_new_tokens" in generate_kwargs:
+                new_tokens = generate_kwargs["max_new_tokens"]
+            else:
+                new_tokens = generate_kwargs.get("max_length", self.model.config.max_length) - cur_len
+                if new_tokens < 0:
+                    raise ValueError("We cannot infer how many new tokens are expected")
+            if cur_len + new_tokens > self.tokenizer.model_max_length:
+                keep_length = self.tokenizer.model_max_length - new_tokens
+                if keep_length <= 0:
+                    raise ValueError(
+                        "We cannot use `hole` to handle this generation the number of desired tokens exceeds the models max length"
+                    )
+
+                inputs["input_ids"] = inputs["input_ids"][:, -keep_length:]
+                if "attention_mask" in inputs:
+                    inputs["attention_mask"] = inputs["attention_mask"][:, -keep_length:]
+
         return inputs

     def _forward(self, model_inputs, **generate_kwargs):
...
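
The slicing in `preprocess` is the heart of the `"hole"` strategy. Below is a standalone sketch of the same arithmetic, assuming plain Python lists instead of framework tensors (hypothetical helper, not part of the diff):

```python
# A hypothetical, framework-free sketch of the "hole" strategy implemented in preprocess above:
# keep only the rightmost prompt tokens so that keep_length + new_tokens fits within the
# tokenizer's model_max_length.
def hole_truncate(input_ids, new_tokens, model_max_length):
    cur_len = len(input_ids)
    if cur_len + new_tokens <= model_max_length:
        return input_ids  # already fits, nothing to do
    keep_length = model_max_length - new_tokens
    if keep_length <= 0:
        raise ValueError("the number of desired new tokens exceeds the model's max length")
    return input_ids[-keep_length:]  # truncate from the left, like input_ids[:, -keep_length:]


ids = list(range(2000))                                   # pretend prompt of 2000 token ids
kept = hole_truncate(ids, new_tokens=20, model_max_length=1024)
print(len(kept))                                          # 1004 == 1024 - 20
```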
@@ -143,7 +143,9 @@ class PipelineTestCaseMeta(type):
                 try:
                     tokenizer = get_tiny_tokenizer_from_checkpoint(checkpoint)
                     # XLNet actually defines it as -1.
-                    if (
+                    if model.config.__class__.__name__ == "RobertaConfig":
+                        tokenizer.model_max_length = model.config.max_position_embeddings - 2
+                    elif (
                         hasattr(model.config, "max_position_embeddings")
                         and model.config.max_position_embeddings > 0
                     ):
...
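
The `RobertaConfig` special case reflects the fact that RoBERTa's learned position embeddings are offset by `padding_idx + 1`. A small sketch of the arithmetic, using the roberta-base values for illustration:

```python
# Hypothetical illustration of the RobertaConfig special case above: RoBERTa's learned
# position embeddings are offset by padding_idx + 1, so max_position_embeddings overstates
# the number of usable positions by 2 (values below are those of roberta-base).
max_position_embeddings = 514
padding_idx = 1                                   # RoBERTa's pad token id
usable_positions = max_position_embeddings - (padding_idx + 1)
print(usable_positions)                           # 512, the value tokenizer.model_max_length should have
```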
@@ -123,3 +123,24 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         else:
             with self.assertRaises((ValueError, AssertionError)):
                 outputs = text_generator("")
+
+        if text_generator.framework == "tf":
+            # TF generation does not support max_new_tokens, and it's impossible
+            # to control long generation with only max_length without
+            # fancy calculation, dismissing tests for now.
+            return
+
+        # We don't care about infinite range models.
+        # They already work.
+        if tokenizer.model_max_length < 10000:
+            # Handling of large generations
+            with self.assertRaises((RuntimeError, IndexError, ValueError, AssertionError)):
+                text_generator("This is a test" * 500, max_new_tokens=20)
+            outputs = text_generator("This is a test" * 500, handle_long_generation="hole", max_new_tokens=20)
+            # Hole strategy cannot work
+            with self.assertRaises(ValueError):
+                text_generator(
+                    "This is a test" * 500,
+                    handle_long_generation="hole",
+                    max_new_tokens=tokenizer.model_max_length + 10,
+                )