"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c603d099aa24410ec5a60c23794cc4a293d92850"
Unverified Commit 06dd5975 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

fix bug in warnings T5 pipelines (#3545)

parent 9de9ceb6
...@@ -1235,17 +1235,19 @@ class SummarizationPipeline(Pipeline): ...@@ -1235,17 +1235,19 @@ class SummarizationPipeline(Pipeline):
elif self.framework == "tf": elif self.framework == "tf":
input_length = tf.shape(inputs["input_ids"])[-1].numpy() input_length = tf.shape(inputs["input_ids"])[-1].numpy()
if input_length < self.model.config.min_length // 2: min_length = generate_kwargs.get("min_length", self.model.config.min_length)
if input_length < min_length // 2:
logger.warning( logger.warning(
"Your min_length is set to {}, but you input_length is only {}. You might consider decreasing min_length manually, e.g. summarizer('...', min_length=10)".format( "Your min_length is set to {}, but you input_length is only {}. You might consider decreasing min_length manually, e.g. summarizer('...', min_length=10)".format(
self.model.config.min_length, input_length min_length, input_length
) )
) )
if input_length < self.model.config.max_length: max_length = generate_kwargs.get("max_length", self.model.config.max_length)
if input_length < max_length:
logger.warning( logger.warning(
"Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format( "Your max_length is set to {}, but you input_length is only {}. You might consider decreasing max_length manually, e.g. summarizer('...', max_length=50)".format(
self.model.config.max_length, input_length max_length, input_length
) )
) )
...@@ -1349,10 +1351,11 @@ class TranslationPipeline(Pipeline): ...@@ -1349,10 +1351,11 @@ class TranslationPipeline(Pipeline):
elif self.framework == "tf": elif self.framework == "tf":
input_length = tf.shape(inputs["input_ids"])[-1].numpy() input_length = tf.shape(inputs["input_ids"])[-1].numpy()
if input_length > 0.9 * self.model.config.max_length: max_length = generate_kwargs.get("max_length", self.model.config.max_length)
if input_length > 0.9 * max_length:
logger.warning( logger.warning(
"Your input_length: {} is bigger than 0.9 * max_length: {}. You might consider increasing your max_length manually, e.g. translator('...', max_length=400)".format( "Your input_length: {} is bigger than 0.9 * max_length: {}. You might consider increasing your max_length manually, e.g. translator('...', max_length=400)".format(
input_length, self.model.config.max_length input_length, max_length
) )
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment