Unverified Commit deb7605a authored by Yen Ting's avatar Yen Ting Committed by GitHub
Browse files

Prevent `TextGenerationPipeline._sanitize_parameters` from overriding...


Prevent `TextGenerationPipeline._sanitize_parameters` from overriding previously provided parameters (#30362)

* Fixed TextGenerationPipeline._sanitize_parameters default params

* removed empty spaces

---------
Co-authored-by: default avatarNg, Yen Ting <yen.ting.ng@intel.com>
parent d0c72c15
...@@ -129,19 +129,25 @@ class TextGenerationPipeline(Pipeline): ...@@ -129,19 +129,25 @@ class TextGenerationPipeline(Pipeline):
prefix=None, prefix=None,
handle_long_generation=None, handle_long_generation=None,
stop_sequence=None, stop_sequence=None,
add_special_tokens=False,
truncation=None, truncation=None,
padding=False,
max_length=None, max_length=None,
**generate_kwargs, **generate_kwargs,
): ):
preprocess_params = { preprocess_params = {}
"add_special_tokens": add_special_tokens,
"truncation": truncation, add_special_tokens = False
"padding": padding, if "add_special_tokens" in generate_kwargs:
"max_length": max_length, preprocess_params["add_special_tokens"] = generate_kwargs["add_special_tokens"]
} add_special_tokens = generate_kwargs["add_special_tokens"]
if "padding" in generate_kwargs:
preprocess_params["padding"] = generate_kwargs["padding"]
if truncation is not None:
preprocess_params["truncation"] = truncation
if max_length is not None: if max_length is not None:
preprocess_params["max_length"] = max_length
generate_kwargs["max_length"] = max_length generate_kwargs["max_length"] = max_length
if prefix is not None: if prefix is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment