"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "f497f564bb76697edab09184a252fc1b1a326d1e"
Unverified commit f2c1df93, authored by Arthur, committed by GitHub

[`generate`] Only warn users if the `generation_config`'s `max_length` is set to the default value (#25030)

* check max length is default

* nit

* update warning: no longer deprecate

* comment in configuration_utils in case `max_length`'s default gets changed in the future
parent c879318c
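
In practice, the new gate means a checkpoint whose generation config keeps the library-wide default no longer emits a warning on a bare `generate()` call. A minimal sketch of the before/after behavior (gpt2 is used only as an example of a checkpoint that keeps the default `max_length`):

from transformers import AutoModelForCausalLM, AutoTokenizer

# gpt2 is a placeholder for any checkpoint whose generation config keeps
# the library default max_length of 20.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

# Before this commit: warned whenever neither `max_length` nor
# `max_new_tokens` was passed. After: silent here, since max_length == 20.
model.generate(**inputs)

# Passing an explicit cap avoids the fallback path in both versions.
model.generate(**inputs, max_new_tokens=40)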
@@ -234,6 +234,7 @@ class GenerationConfig(PushToHubMixin):
     def __init__(self, **kwargs):
         # Parameters that control the length of the output
+        # if the default `max_length` is updated here, make sure to update the `generate` tests following https://github.com/huggingface/transformers/pull/25030
         self.max_length = kwargs.pop("max_length", 20)
         self.max_new_tokens = kwargs.pop("max_new_tokens", None)
         self.min_length = kwargs.pop("min_length", 0)
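
The `20` the new checks compare against is exactly this default, which is why the added comment pins it to the PR. A quick illustration using the public `GenerationConfig` API:

from transformers import GenerationConfig

# The model-agnostic default that the new warning logic treats as "unset".
assert GenerationConfig().max_length == 20

# A config that ships a different value will now trigger the warning when
# `generate()` receives neither `max_length` nor `max_new_tokens`.
print(GenerationConfig(max_length=256).max_length)  # 256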
@@ -377,11 +377,11 @@ class FlaxGenerationMixin:
         # Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
-        if has_default_max_length and generation_config.max_new_tokens is None:
+        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length != 20:
+            # 20 is the default max_length of the generation config
             warnings.warn(
-                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
-                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
-                " recommend using `max_new_tokens` to control the maximum length of the generation.",
+                f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
+                "to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.",
                 UserWarning,
             )
         elif generation_config.max_new_tokens is not None:
@@ -850,11 +850,11 @@ class TFGenerationMixin:
         # 7. Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = shape_list(input_ids)[-1]
         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
-        if has_default_max_length and generation_config.max_new_tokens is None:
+        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length != 20:
+            # 20 is the default max_length of the generation config
             warnings.warn(
-                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
-                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
-                " recommend using `max_new_tokens` to control the maximum length of the generation.",
+                f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
+                "to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.",
                 UserWarning,
             )
         elif generation_config.max_new_tokens is not None:
@@ -1365,11 +1365,11 @@ class GenerationMixin:
         # 6. Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
-        if has_default_max_length and generation_config.max_new_tokens is None:
+        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length != 20:
+            # 20 is the default max_length of the generation config
             warnings.warn(
-                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
-                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
-                " recommend using `max_new_tokens` to control the maximum length of the generation.",
+                f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
+                "to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.",
                 UserWarning,
             )
         elif generation_config.max_new_tokens is not None:
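
The same three-part gate is applied verbatim in the Flax, TF, and PyTorch mixins above. Restated as a standalone helper purely for illustration (not library code):

DEFAULT_MAX_LENGTH = 20  # default set in GenerationConfig.__init__

def should_warn_about_max_length(kwargs: dict, generation_config) -> bool:
    # `generate()` was not given `max_length` directly...
    has_default_max_length = (
        kwargs.get("max_length") is None and generation_config.max_length is not None
    )
    # ...no `max_new_tokens` was set, and the config carries a
    # model-specific `max_length` that will silently control generation.
    return (
        has_default_max_length
        and generation_config.max_new_tokens is None
        and generation_config.max_length != DEFAULT_MAX_LENGTH
    )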
@@ -1303,11 +1303,10 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
         # 5. Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
-        if has_default_max_length and generation_config.max_new_tokens is None:
+        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length != 20:
             logger.warning(
-                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
-                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
-                " recommend using `max_new_tokens` to control the maximum length of the generation.",
+                f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
+                "to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.",
                 UserWarning,
             )
         elif generation_config.max_new_tokens is not None:
@@ -2332,9 +2331,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
         if has_default_max_length and generation_config.max_new_tokens is None:
             logger.warning(
-                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
-                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
-                " recommend using `max_new_tokens` to control the maximum length of the generation.",
+                f"Using the model-agnostic default `max_length` (={generation_config.max_length}) "
+                "to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.",
                 UserWarning,
             )
         elif generation_config.max_new_tokens is not None:
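
One quirk the diff carries over: both Musicgen call sites pass `UserWarning` to `logger.warning`, which treats extra positional arguments as %-format arguments. Because the message contains no format placeholders, the standard logging machinery will typically report a formatting error when the record is rendered, rather than printing the message cleanly. A sketch of the two conventional patterns (not part of this commit):

import logging
import warnings

logger = logging.getLogger(__name__)
msg = "Using the model-agnostic default `max_length` ..."

warnings.warn(msg, UserWarning)  # `warnings` takes a category argument
logger.warning(msg)              # `logging` does not; pass the message only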