".github/vscode:/vscode.git/clone" did not exist on "a01fe4cd32d38f63a98ebfaf9c8912dfe6a4aa5e"
Unverified Commit 510270af authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: `GenerationConfig` throws an exception when `generate` args are passed (#27757)

parent fe41647a
...@@ -497,6 +497,24 @@ class GenerationConfig(PushToHubMixin): ...@@ -497,6 +497,24 @@ class GenerationConfig(PushToHubMixin):
f"({self.num_beams})." f"({self.num_beams})."
) )
# 5. check common issue: passing `generate` arguments inside the generation config
generate_arguments = (
"logits_processor",
"stopping_criteria",
"prefix_allowed_tokens_fn",
"synced_gpus",
"assistant_model",
"streamer",
"negative_prompt_ids",
"negative_prompt_attention_mask",
)
for arg in generate_arguments:
if hasattr(self, arg):
raise ValueError(
f"Argument `{arg}` is not a valid argument of `GenerationConfig`. It should be passed to "
"`generate()` (or a pipeline) directly."
)
def save_pretrained( def save_pretrained(
self, self,
save_directory: Union[str, os.PathLike], save_directory: Union[str, os.PathLike],
......
...@@ -120,6 +120,34 @@ class GenerationConfigTest(unittest.TestCase): ...@@ -120,6 +120,34 @@ class GenerationConfigTest(unittest.TestCase):
self.assertEqual(loaded_config.do_sample, True) self.assertEqual(loaded_config.do_sample, True)
self.assertEqual(loaded_config.num_beams, 1) # default value self.assertEqual(loaded_config.num_beams, 1) # default value
def test_validate(self):
"""
Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time
"""
# Case 1: A correct configuration will not throw any warning
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig()
self.assertEqual(len(captured_warnings), 0)
# Case 2: Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
# parameters with `do_sample=False`). May be escalated to an error in the future.
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig(temperature=0.5)
self.assertEqual(len(captured_warnings), 1)
# Case 3: Impossible sets of contraints/parameters will raise an exception
with self.assertRaises(ValueError):
GenerationConfig(num_return_sequences=2)
# Case 4: Passing `generate()`-only flags to `validate` will raise an exception
with self.assertRaises(ValueError):
GenerationConfig(logits_processor="foo")
# Case 5: Model-specific parameters will NOT raise an exception or a warning
with warnings.catch_warnings(record=True) as captured_warnings:
GenerationConfig(foo="bar")
self.assertEqual(len(captured_warnings), 0)
def test_refuse_to_save(self): def test_refuse_to_save(self):
"""Tests that we refuse to save a generation config that fails validation.""" """Tests that we refuse to save a generation config that fails validation."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment