Unverified Commit 366c0327 authored by thedamnedrhino's avatar thedamnedrhino Committed by GitHub
Browse files

Tokenizer kwargs in textgeneration pipe (#28362)

* added args to the pipeline

* added test

* more sensical tests

* fixup

* docs

* typo
;

* docs

* made changes to support named args

* fixed test

* docs update

* styles

* docs

* docs
parent a573ac74
...@@ -216,6 +216,12 @@ array([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], ...@@ -216,6 +216,12 @@ array([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
</tf> </tf>
</frameworkcontent> </frameworkcontent>
<Tip>
Different pipelines support tokenizer arguments in their `__call__()` differently. `text2text-generation` pipelines support (i.e. pass on)
only `truncation`. `text-generation` pipelines support `max_length`, `truncation`, `padding`, and `add_special_tokens`.
In `fill-mask` pipelines, tokenizer arguments can be passed in the `tokenizer_kwargs` argument (a dictionary).
</Tip>
## Audio ## Audio
For audio tasks, you'll need a [feature extractor](main_classes/feature_extractor) to prepare your dataset for the model. The feature extractor is designed to extract features from raw audio data, and convert them into tensors. For audio tasks, you'll need a [feature extractor](main_classes/feature_extractor) to prepare your dataset for the model. The feature extractor is designed to extract features from raw audio data, and convert them into tensors.
......
...@@ -104,9 +104,20 @@ class TextGenerationPipeline(Pipeline): ...@@ -104,9 +104,20 @@ class TextGenerationPipeline(Pipeline):
handle_long_generation=None, handle_long_generation=None,
stop_sequence=None, stop_sequence=None,
add_special_tokens=False, add_special_tokens=False,
truncation=None,
padding=False,
max_length=None,
**generate_kwargs, **generate_kwargs,
): ):
preprocess_params = {"add_special_tokens": add_special_tokens} preprocess_params = {
"add_special_tokens": add_special_tokens,
"truncation": truncation,
"padding": padding,
"max_length": max_length,
}
if max_length is not None:
generate_kwargs["max_length"] = max_length
if prefix is not None: if prefix is not None:
preprocess_params["prefix"] = prefix preprocess_params["prefix"] = prefix
if prefix: if prefix:
...@@ -208,10 +219,23 @@ class TextGenerationPipeline(Pipeline): ...@@ -208,10 +219,23 @@ class TextGenerationPipeline(Pipeline):
return super().__call__(text_inputs, **kwargs) return super().__call__(text_inputs, **kwargs)
def preprocess( def preprocess(
self, prompt_text, prefix="", handle_long_generation=None, add_special_tokens=False, **generate_kwargs self,
prompt_text,
prefix="",
handle_long_generation=None,
add_special_tokens=False,
truncation=None,
padding=False,
max_length=None,
**generate_kwargs,
): ):
inputs = self.tokenizer( inputs = self.tokenizer(
prefix + prompt_text, padding=False, add_special_tokens=add_special_tokens, return_tensors=self.framework prefix + prompt_text,
return_tensors=self.framework,
truncation=truncation,
padding=padding,
max_length=max_length,
add_special_tokens=add_special_tokens,
) )
inputs["prompt_text"] = prompt_text inputs["prompt_text"] = prompt_text
......
...@@ -90,6 +90,22 @@ class TextGenerationPipelineTests(unittest.TestCase): ...@@ -90,6 +90,22 @@ class TextGenerationPipelineTests(unittest.TestCase):
{"generated_token_ids": ANY(list)}, {"generated_token_ids": ANY(list)},
], ],
) )
## -- test tokenizer_kwargs
test_str = "testing tokenizer kwargs. using truncation must result in a different generation."
output_str, output_str_with_truncation = (
text_generator(test_str, do_sample=False, return_full_text=False)[0]["generated_text"],
text_generator(
test_str,
do_sample=False,
return_full_text=False,
truncation=True,
max_length=3,
)[0]["generated_text"],
)
assert output_str != output_str_with_truncation  # results must be different because one had truncation
# -- what is the point of this test? padding is hardcoded False in the pipeline anyway
text_generator.tokenizer.pad_token_id = text_generator.model.config.eos_token_id text_generator.tokenizer.pad_token_id = text_generator.model.config.eos_token_id
text_generator.tokenizer.pad_token = "<pad>" text_generator.tokenizer.pad_token = "<pad>"
outputs = text_generator( outputs = text_generator(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment