Unverified Commit b0bf3011 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: min length can't be larger than max length (#16668)

* min length must be smaller than max length

* Update min_length in tests
parent 4868a830
......@@ -259,6 +259,7 @@ class FlaxGenerationMixin:
```"""
# set init values
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
......@@ -269,6 +270,11 @@ class FlaxGenerationMixin:
if decoder_start_token_id is None and self.config.is_encoder_decoder:
raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
if min_length is not None and min_length > max_length:
raise ValueError(
f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
f"length ({max_length})"
)
if self.config.is_encoder_decoder:
# add encoder_outputs to model_kwargs
......@@ -389,7 +395,6 @@ class FlaxGenerationMixin:
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
)
min_length = min_length if min_length is not None else self.config.min_length
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
forced_bos_token_id = (
forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
......
......@@ -1489,6 +1489,11 @@ class TFGenerationMixin:
if pad_token_id is None and eos_token_id is not None:
logger.warning(f"Setting `pad_token_id` to {eos_token_id} (first `eos_token_id`) to generate sequence")
pad_token_id = eos_token_id
if min_length is not None and min_length > max_length:
raise ValueError(
f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
f"length ({max_length})"
)
# 2. Define model inputs
input_ids = self._prepare_model_inputs(input_ids, bos_token_id)
......
......@@ -700,7 +700,6 @@ class GenerationMixin:
else self.config.encoder_no_repeat_ngram_size
)
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
min_length = min_length if min_length is not None else self.config.min_length
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty
forced_bos_token_id = (
......@@ -1185,7 +1184,13 @@ class GenerationMixin:
)
# default to config if still None
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
if min_length is not None and min_length > max_length:
raise ValueError(
f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
f"length ({max_length})"
)
if input_ids_seq_length >= max_length:
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
logger.warning(
......
......@@ -102,7 +102,7 @@ class GenerationTesterMixin:
diversity_penalty=None,
):
process_kwargs = {
"min_length": input_length + 1,
"min_length": input_length + 1 if max_length is None else max_length - 1,
"bad_words_ids": [[1, 0]],
"no_repeat_ngram_size": 2,
"repetition_penalty": 1.2,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment