Unverified Commit 53357e81 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Adding ValueError when imcompatible parameters are used. (#20729)

parent 5ba2dbd9
...@@ -130,6 +130,8 @@ class TextGenerationPipeline(Pipeline): ...@@ -130,6 +130,8 @@ class TextGenerationPipeline(Pipeline):
postprocess_params = {} postprocess_params = {}
if return_full_text is not None and return_type is None: if return_full_text is not None and return_type is None:
if return_text is not None:
raise ValueError("`return_text` is mutually exclusive with `return_full_text`")
return_type = ReturnType.FULL_TEXT if return_full_text else ReturnType.NEW_TEXT return_type = ReturnType.FULL_TEXT if return_full_text else ReturnType.NEW_TEXT
if return_tensors is not None and return_type is None: if return_tensors is not None and return_type is None:
return_type = ReturnType.TENSORS return_type = ReturnType.TENSORS
......
...@@ -201,6 +201,9 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM ...@@ -201,6 +201,9 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
], ],
) )
with self.assertRaises(ValueError):
outputs = text_generator("test", return_full_text=True, return_text=True)
# Empty prompt is slighly special # Empty prompt is slighly special
# it requires BOS token to exist. # it requires BOS token to exist.
# Special case for Pegasus which will always append EOS so will # Special case for Pegasus which will always append EOS so will
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment