Unverified Commit 3a08dc63 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: better warnings with pipelines (#23128)

parent 2a16d8b2
......@@ -803,10 +803,12 @@ class Pipeline(_ScikitCompat):
self.torch_dtype = torch_dtype
self.binary_output = binary_output
# Update config with task specific parameters
# Update config and generation_config with task specific parameters
task_specific_params = self.model.config.task_specific_params
if task_specific_params is not None and task in task_specific_params:
self.model.config.update(task_specific_params.get(task))
if self.model.can_generate():
self.model.generation_config.update(**task_specific_params.get(task))
self.call_count = 0
self._batch_size = kwargs.pop("batch_size", None)
......
......@@ -273,7 +273,8 @@ class SummarizationPipeline(Text2TextGenerationPipeline):
if input_length < max_length:
logger.warning(
f"Your max_length is set to {max_length}, but your input_length is only {input_length}. You might "
f"Your max_length is set to {max_length}, but your input_length is only {input_length}. Since this is "
"a summarization task, where outputs shorter than the input are typically wanted, you might "
f"consider decreasing max_length manually, e.g. summarizer('...', max_length={input_length//2})"
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment