Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f4f57f9d
Unverified
Commit
f4f57f9d
authored
Jan 16, 2024
by
Joao Gante
Committed by
GitHub
Jan 16, 2024
Browse files
Config: warning when saving generation kwargs in the model config (#28514)
parent
7142bdfa
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
107 additions
and
32 deletions
+107
-32
src/transformers/configuration_utils.py
src/transformers/configuration_utils.py
+56
-27
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+7
-4
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+18
-0
tests/generation/test_framework_agnostic.py
tests/generation/test_framework_agnostic.py
+1
-1
tests/test_configuration_utils.py
tests/test_configuration_utils.py
+16
-0
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+9
-0
No files found.
src/transformers/configuration_utils.py
View file @
f4f57f9d
...
...
@@ -277,6 +277,7 @@ class PretrainedConfig(PushToHubMixin):
self
.
tie_word_embeddings
=
kwargs
.
pop
(
"tie_word_embeddings"
,
True
)
# Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.
self
.
chunk_size_feed_forward
=
kwargs
.
pop
(
"chunk_size_feed_forward"
,
0
)
# `is_decoder` is used in encoder-decoder models to differentiate the encoder from the decoder
self
.
is_encoder_decoder
=
kwargs
.
pop
(
"is_encoder_decoder"
,
False
)
...
...
@@ -285,33 +286,10 @@ class PretrainedConfig(PushToHubMixin):
self
.
add_cross_attention
=
kwargs
.
pop
(
"add_cross_attention"
,
False
)
self
.
tie_encoder_decoder
=
kwargs
.
pop
(
"tie_encoder_decoder"
,
False
)
# Parameters for sequence generation
self
.
max_length
=
kwargs
.
pop
(
"max_length"
,
20
)
self
.
min_length
=
kwargs
.
pop
(
"min_length"
,
0
)
self
.
do_sample
=
kwargs
.
pop
(
"do_sample"
,
False
)
self
.
early_stopping
=
kwargs
.
pop
(
"early_stopping"
,
False
)
self
.
num_beams
=
kwargs
.
pop
(
"num_beams"
,
1
)
self
.
num_beam_groups
=
kwargs
.
pop
(
"num_beam_groups"
,
1
)
self
.
diversity_penalty
=
kwargs
.
pop
(
"diversity_penalty"
,
0.0
)
self
.
temperature
=
kwargs
.
pop
(
"temperature"
,
1.0
)
self
.
top_k
=
kwargs
.
pop
(
"top_k"
,
50
)
self
.
top_p
=
kwargs
.
pop
(
"top_p"
,
1.0
)
self
.
typical_p
=
kwargs
.
pop
(
"typical_p"
,
1.0
)
self
.
repetition_penalty
=
kwargs
.
pop
(
"repetition_penalty"
,
1.0
)
self
.
length_penalty
=
kwargs
.
pop
(
"length_penalty"
,
1.0
)
self
.
no_repeat_ngram_size
=
kwargs
.
pop
(
"no_repeat_ngram_size"
,
0
)
self
.
encoder_no_repeat_ngram_size
=
kwargs
.
pop
(
"encoder_no_repeat_ngram_size"
,
0
)
self
.
bad_words_ids
=
kwargs
.
pop
(
"bad_words_ids"
,
None
)
self
.
num_return_sequences
=
kwargs
.
pop
(
"num_return_sequences"
,
1
)
self
.
chunk_size_feed_forward
=
kwargs
.
pop
(
"chunk_size_feed_forward"
,
0
)
self
.
output_scores
=
kwargs
.
pop
(
"output_scores"
,
False
)
self
.
return_dict_in_generate
=
kwargs
.
pop
(
"return_dict_in_generate"
,
False
)
self
.
forced_bos_token_id
=
kwargs
.
pop
(
"forced_bos_token_id"
,
None
)
self
.
forced_eos_token_id
=
kwargs
.
pop
(
"forced_eos_token_id"
,
None
)
self
.
remove_invalid_values
=
kwargs
.
pop
(
"remove_invalid_values"
,
False
)
self
.
exponential_decay_length_penalty
=
kwargs
.
pop
(
"exponential_decay_length_penalty"
,
None
)
self
.
suppress_tokens
=
kwargs
.
pop
(
"suppress_tokens"
,
None
)
self
.
begin_suppress_tokens
=
kwargs
.
pop
(
"begin_suppress_tokens"
,
None
)
# Retrocompatibility: Parameters for sequence generation. While we will keep the ability to load these
# parameters, saving them will be deprecated. In a distant future, we won't need to load them.
for
parameter_name
,
default_value
in
self
.
_get_generation_defaults
().
items
():
setattr
(
self
,
parameter_name
,
kwargs
.
pop
(
parameter_name
,
default_value
))
# Fine-tuning task arguments
self
.
architectures
=
kwargs
.
pop
(
"architectures"
,
None
)
...
...
@@ -463,6 +441,18 @@ class PretrainedConfig(PushToHubMixin):
if
os
.
path
.
isfile
(
save_directory
):
raise
AssertionError
(
f
"Provided path (
{
save_directory
}
) should be a directory, not a file"
)
non_default_generation_parameters
=
{}
for
parameter_name
,
default_value
in
self
.
_get_generation_defaults
().
items
():
if
hasattr
(
self
,
parameter_name
)
and
getattr
(
self
,
parameter_name
)
!=
default_value
:
non_default_generation_parameters
[
parameter_name
]
=
getattr
(
self
,
parameter_name
)
if
len
(
non_default_generation_parameters
)
>
0
:
logger
.
warning
(
"Some non-default generation parameters are set in the model config. These should go into a "
"GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) "
"instead. This warning will be raised to an exception in v4.41.
\n
"
f
"Non-default generation parameters:
{
str
(
non_default_generation_parameters
)
}
"
)
os
.
makedirs
(
save_directory
,
exist_ok
=
True
)
if
push_to_hub
:
...
...
@@ -1050,6 +1040,45 @@ class PretrainedConfig(PushToHubMixin):
cls
.
_auto_class
=
auto_class
@
staticmethod
def
_get_generation_defaults
()
->
Dict
[
str
,
Any
]:
return
{
"max_length"
:
20
,
"min_length"
:
0
,
"do_sample"
:
False
,
"early_stopping"
:
False
,
"num_beams"
:
1
,
"num_beam_groups"
:
1
,
"diversity_penalty"
:
0.0
,
"temperature"
:
1.0
,
"top_k"
:
50
,
"top_p"
:
1.0
,
"typical_p"
:
1.0
,
"repetition_penalty"
:
1.0
,
"length_penalty"
:
1.0
,
"no_repeat_ngram_size"
:
0
,
"encoder_no_repeat_ngram_size"
:
0
,
"bad_words_ids"
:
None
,
"num_return_sequences"
:
1
,
"output_scores"
:
False
,
"return_dict_in_generate"
:
False
,
"forced_bos_token_id"
:
None
,
"forced_eos_token_id"
:
None
,
"remove_invalid_values"
:
False
,
"exponential_decay_length_penalty"
:
None
,
"suppress_tokens"
:
None
,
"begin_suppress_tokens"
:
None
,
}
def
_has_non_default_generation_parameters
(
self
)
->
bool
:
"""
Whether or not this instance holds non-default generation parameters.
"""
for
parameter_name
,
default_value
in
self
.
_get_generation_defaults
().
items
():
if
hasattr
(
self
,
parameter_name
)
and
getattr
(
self
,
parameter_name
)
!=
default_value
:
return
True
return
False
def
get_configuration_file
(
configuration_files
:
List
[
str
])
->
str
:
"""
...
...
src/transformers/generation/utils.py
View file @
f4f57f9d
...
...
@@ -1274,11 +1274,14 @@ class GenerationMixin:
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if
generation_config
is
None
:
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
# two conditions must be met
# three conditions must be met
# 1) the generation config must have been created from the model config (`_from_model_config` field);
# 2) the generation config must have seen no modification since its creation (the hash is the same).
if
self
.
generation_config
.
_from_model_config
and
self
.
generation_config
.
_original_object_hash
==
hash
(
self
.
generation_config
# 2) the generation config must have seen no modification since its creation (the hash is the same);
# 3) the user must have set generation parameters in the model config.
if
(
self
.
generation_config
.
_from_model_config
and
self
.
generation_config
.
_original_object_hash
==
hash
(
self
.
generation_config
)
and
self
.
config
.
_has_non_default_generation_parameters
()
):
new_generation_config
=
GenerationConfig
.
from_model_config
(
self
.
config
)
if
new_generation_config
!=
self
.
generation_config
:
...
...
src/transformers/modeling_utils.py
View file @
f4f57f9d
...
...
@@ -2335,6 +2335,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if
not
_hf_peft_config_loaded
:
model_to_save
.
config
.
save_pretrained
(
save_directory
)
if
self
.
can_generate
():
# generation config built from the model config + the model config holds generation kwargs -> generate
# may revert to legacy behavior if the two don't match
if
(
model_to_save
.
generation_config
.
_from_model_config
and
model_to_save
.
config
.
_has_non_default_generation_parameters
()
):
new_generation_config
=
GenerationConfig
.
from_model_config
(
model_to_save
.
config
)
if
new_generation_config
!=
model_to_save
.
generation_config
:
logger
.
warning
(
"Your generation config was originally created from the model config, but the model "
"config has changed since then. Unless you pass the `generation_config` argument to this "
"model's `generate` calls, they will revert to the legacy behavior where the base "
"`generate` parameterization is loaded from the model config instead. "
"To avoid this behavior and this warning, we recommend you to overwrite the generation "
"config model attribute before calling the model's `save_pretrained`, preferably also "
"removing any generation kwargs from the model config. This warning will be raised to an "
"exception in v4.41."
)
model_to_save
.
generation_config
.
save_pretrained
(
save_directory
)
if
_hf_peft_config_loaded
:
...
...
tests/generation/test_framework_agnostic.py
View file @
f4f57f9d
...
...
@@ -529,7 +529,7 @@ class GenerationIntegrationTestsMixin:
pixel_values
=
floats_tensor
((
2
,
3
,
30
,
30
))
model
=
model_cls
.
from_pretrained
(
"hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2"
)
model
.
config
.
decoder
.
eos_token_id
=
None
model
.
generation_config
.
eos_token_id
=
None
if
is_pt
:
pixel_values
=
pixel_values
.
to
(
torch_device
)
model
=
model
.
to
(
torch_device
)
...
...
tests/test_configuration_utils.py
View file @
f4f57f9d
...
...
@@ -296,3 +296,19 @@ class ConfigTestUtils(unittest.TestCase):
old_transformers
.
configuration_utils
.
__version__
=
"v3.0.0"
old_configuration
=
old_transformers
.
models
.
auto
.
AutoConfig
.
from_pretrained
(
repo
)
self
.
assertEqual
(
old_configuration
.
hidden_size
,
768
)
def test_saving_config_with_custom_generation_kwargs_raises_warning(self):
    """
    Saving a config that holds a non-default generation kwarg must log exactly one warning on the
    `transformers.configuration_utils` logger, and that warning must mention the offending kwarg.
    """
    config = BertConfig(min_length=3)  # `min_length = 3` is a non-default generation kwarg
    with tempfile.TemporaryDirectory() as tmp_dir:
        with self.assertLogs("transformers.configuration_utils", level="WARNING") as logs:
            config.save_pretrained(tmp_dir)
        self.assertEqual(len(logs.output), 1)
        self.assertIn("min_length", logs.output[0])
def test_has_non_default_generation_parameters(self):
    """
    `_has_non_default_generation_parameters()` must be `False` for a config built without generation kwargs,
    `True` when a generation kwarg differs from its default, and `False` again when the kwarg is explicitly
    set to its default value.
    """
    config = BertConfig()
    self.assertFalse(config._has_non_default_generation_parameters())

    config = BertConfig(min_length=3)
    self.assertTrue(config._has_non_default_generation_parameters())

    config = BertConfig(min_length=0)  # `min_length = 0` is a default generation kwarg
    self.assertFalse(config._has_non_default_generation_parameters())
tests/test_modeling_utils.py
View file @
f4f57f9d
...
...
@@ -1230,6 +1230,15 @@ class ModelUtilsTest(TestCasePlus):
for
p1
,
p2
in
zip
(
model
.
parameters
(),
new_model
.
parameters
()):
self
.
assertTrue
(
torch
.
equal
(
p1
,
p2
))
def test_modifying_model_config_causes_warning_saving_generation_config(self):
    """
    Setting a generation kwarg on the model config after loading, then saving the model, must log exactly one
    warning on the `transformers.modeling_utils` logger stating that the generation config was originally
    created from the model config.
    """
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.config.top_k = 1  # generation kwarg set directly on the model config
    with tempfile.TemporaryDirectory() as tmp_dir:
        with self.assertLogs("transformers.modeling_utils", level="WARNING") as logs:
            model.save_pretrained(tmp_dir)
        self.assertEqual(len(logs.output), 1)
        self.assertIn("Your generation config was originally created from the model config", logs.output[0])
@
slow
@
require_torch
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment