Unverified Commit 871ba71d authored by FredericOdermatt's avatar FredericOdermatt Committed by GitHub
Browse files

GenerationConfig validate both constraints and force_words_ids (#29163)

GenerationConfig validate both options for constrained decoding: constraints and force_words_ids
parent 3fcfbe75
...@@ -482,11 +482,11 @@ class GenerationConfig(PushToHubMixin): ...@@ -482,11 +482,11 @@ class GenerationConfig(PushToHubMixin):
# 3. detect incorrect paramaterization specific to advanced beam modes # 3. detect incorrect paramaterization specific to advanced beam modes
else: else:
# constrained beam search # constrained beam search
if self.constraints is not None: if self.constraints is not None or self.force_words_ids is not None:
constrained_wrong_parameter_msg = ( constrained_wrong_parameter_msg = (
"`constraints` is not `None`, triggering constrained beam search. However, `{flag_name}` is set " "one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. However, "
"to `{flag_value}`, which is incompatible with this generation mode. Set `constraints=None` or " "`{flag_name}` is set to `{flag_value}`, which is incompatible with this generation mode. Set "
"unset `{flag_name}` to continue." + fix_location "`constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue." + fix_location
) )
if self.do_sample is True: if self.do_sample is True:
raise ValueError( raise ValueError(
......
...@@ -156,6 +156,11 @@ class GenerationConfigTest(unittest.TestCase): ...@@ -156,6 +156,11 @@ class GenerationConfigTest(unittest.TestCase):
# Impossible sets of contraints/parameters will raise an exception # Impossible sets of contraints/parameters will raise an exception
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
GenerationConfig(do_sample=False, num_beams=1, num_return_sequences=2) GenerationConfig(do_sample=False, num_beams=1, num_return_sequences=2)
with self.assertRaises(ValueError):
# dummy constraint
GenerationConfig(do_sample=True, num_beams=2, constraints=["dummy"])
with self.assertRaises(ValueError):
GenerationConfig(do_sample=True, num_beams=2, force_words_ids=[[[1, 2, 3]]])
# Passing `generate()`-only flags to `validate` will raise an exception # Passing `generate()`-only flags to `validate` will raise an exception
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment