Unverified Commit 510270af authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: `GenerationConfig` throws an exception when `generate` args are passed (#27757)

parent fe41647a
......@@ -497,6 +497,24 @@ class GenerationConfig(PushToHubMixin):
f"({self.num_beams})."
)
# 5. check common issue: passing `generate` arguments inside the generation config
generate_arguments = (
"logits_processor",
"stopping_criteria",
"prefix_allowed_tokens_fn",
"synced_gpus",
"assistant_model",
"streamer",
"negative_prompt_ids",
"negative_prompt_attention_mask",
)
for arg in generate_arguments:
if hasattr(self, arg):
raise ValueError(
f"Argument `{arg}` is not a valid argument of `GenerationConfig`. It should be passed to "
"`generate()` (or a pipeline) directly."
)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
......
......@@ -120,6 +120,34 @@ class GenerationConfigTest(unittest.TestCase):
self.assertEqual(loaded_config.do_sample, True)
self.assertEqual(loaded_config.num_beams, 1) # default value
def test_validate(self):
"""
Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time
"""
# Case 1: A correct configuration will not throw any warning
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig()
self.assertEqual(len(captured_warnings), 0)
# Case 2: Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
# parameters with `do_sample=False`). May be escalated to an error in the future.
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig(temperature=0.5)
self.assertEqual(len(captured_warnings), 1)
# Case 3: Impossible sets of contraints/parameters will raise an exception
with self.assertRaises(ValueError):
GenerationConfig(num_return_sequences=2)
# Case 4: Passing `generate()`-only flags to `validate` will raise an exception
with self.assertRaises(ValueError):
GenerationConfig(logits_processor="foo")
# Case 5: Model-specific parameters will NOT raise an exception or a warning
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig(foo="bar")
self.assertEqual(len(captured_warnings), 0)
def test_refuse_to_save(self):
"""Tests that we refuse to save a generation config that fails validation."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment