Unverified Commit 871ba71d authored by FredericOdermatt's avatar FredericOdermatt Committed by GitHub
Browse files

GenerationConfig validate both constraints and force_words_ids (#29163)

GenerationConfig validate both options for constrained decoding: constraints and force_words_ids
parent 3fcfbe75
......@@ -482,11 +482,11 @@ class GenerationConfig(PushToHubMixin):
# 3. detect incorrect paramaterization specific to advanced beam modes
else:
# constrained beam search
if self.constraints is not None:
if self.constraints is not None or self.force_words_ids is not None:
constrained_wrong_parameter_msg = (
"`constraints` is not `None`, triggering constrained beam search. However, `{flag_name}` is set "
"to `{flag_value}`, which is incompatible with this generation mode. Set `constraints=None` or "
"unset `{flag_name}` to continue." + fix_location
"one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. However, "
"`{flag_name}` is set to `{flag_value}`, which is incompatible with this generation mode. Set "
"`constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue." + fix_location
)
if self.do_sample is True:
raise ValueError(
......
......@@ -156,6 +156,11 @@ class GenerationConfigTest(unittest.TestCase):
# Impossible sets of contraints/parameters will raise an exception
with self.assertRaises(ValueError):
GenerationConfig(do_sample=False, num_beams=1, num_return_sequences=2)
with self.assertRaises(ValueError):
# dummy constraint
GenerationConfig(do_sample=True, num_beams=2, constraints=["dummy"])
with self.assertRaises(ValueError):
GenerationConfig(do_sample=True, num_beams=2, force_words_ids=[[[1, 2, 3]]])
# Passing `generate()`-only flags to `validate` will raise an exception
with self.assertRaises(ValueError):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment