"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "9b5a6450d481b0f02834684ffd8b3ba4cbbd6fe0"
Unverified Commit 03e309d5 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Text2text pipeline: don't parameterize from the config (#26118)

parent 4fb64e28
...@@ -181,9 +181,11 @@ class Text2TextGenerationPipeline(Pipeline): ...@@ -181,9 +181,11 @@ class Text2TextGenerationPipeline(Pipeline):
elif self.framework == "tf": elif self.framework == "tf":
in_b, input_length = tf.shape(model_inputs["input_ids"]).numpy() in_b, input_length = tf.shape(model_inputs["input_ids"]).numpy()
generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length) self.check_inputs(
generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length) input_length,
self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"]) generate_kwargs.get("min_length", self.model.config.min_length),
generate_kwargs.get("max_length", self.model.config.max_length),
)
output_ids = self.model.generate(**model_inputs, **generate_kwargs) output_ids = self.model.generate(**model_inputs, **generate_kwargs)
out_b = output_ids.shape[0] out_b = output_ids.shape[0]
if self.framework == "pt": if self.framework == "pt":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment