Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f4f57f9d
Unverified
Commit
f4f57f9d
authored
Jan 16, 2024
by
Joao Gante
Committed by
GitHub
Jan 16, 2024
Browse files
Config: warning when saving generation kwargs in the model config (#28514)
parent
7142bdfa
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
107 additions
and
32 deletions
+107
-32
src/transformers/configuration_utils.py
src/transformers/configuration_utils.py
+56
-27
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+7
-4
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+18
-0
tests/generation/test_framework_agnostic.py
tests/generation/test_framework_agnostic.py
+1
-1
tests/test_configuration_utils.py
tests/test_configuration_utils.py
+16
-0
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+9
-0
No files found.
src/transformers/configuration_utils.py
View file @
f4f57f9d
...
@@ -277,6 +277,7 @@ class PretrainedConfig(PushToHubMixin):
...
@@ -277,6 +277,7 @@ class PretrainedConfig(PushToHubMixin):
self
.
tie_word_embeddings
=
kwargs
.
pop
(
self
.
tie_word_embeddings
=
kwargs
.
pop
(
"tie_word_embeddings"
,
True
"tie_word_embeddings"
,
True
)
# Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.
)
# Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.
self
.
chunk_size_feed_forward
=
kwargs
.
pop
(
"chunk_size_feed_forward"
,
0
)
# Is decoder is used in encoder-decoder models to differentiate encoder from decoder
# Is decoder is used in encoder-decoder models to differentiate encoder from decoder
self
.
is_encoder_decoder
=
kwargs
.
pop
(
"is_encoder_decoder"
,
False
)
self
.
is_encoder_decoder
=
kwargs
.
pop
(
"is_encoder_decoder"
,
False
)
...
@@ -285,33 +286,10 @@ class PretrainedConfig(PushToHubMixin):
...
@@ -285,33 +286,10 @@ class PretrainedConfig(PushToHubMixin):
self
.
add_cross_attention
=
kwargs
.
pop
(
"add_cross_attention"
,
False
)
self
.
add_cross_attention
=
kwargs
.
pop
(
"add_cross_attention"
,
False
)
self
.
tie_encoder_decoder
=
kwargs
.
pop
(
"tie_encoder_decoder"
,
False
)
self
.
tie_encoder_decoder
=
kwargs
.
pop
(
"tie_encoder_decoder"
,
False
)
# Parameters for sequence generation
# Retrocompatibility: Parameters for sequence generation. While we will keep the ability to load these
self
.
max_length
=
kwargs
.
pop
(
"max_length"
,
20
)
# parameters, saving them will be deprecated. In a distant future, we won't need to load them.
self
.
min_length
=
kwargs
.
pop
(
"min_length"
,
0
)
for
parameter_name
,
default_value
in
self
.
_get_generation_defaults
().
items
():
self
.
do_sample
=
kwargs
.
pop
(
"do_sample"
,
False
)
setattr
(
self
,
parameter_name
,
kwargs
.
pop
(
parameter_name
,
default_value
))
self
.
early_stopping
=
kwargs
.
pop
(
"early_stopping"
,
False
)
self
.
num_beams
=
kwargs
.
pop
(
"num_beams"
,
1
)
self
.
num_beam_groups
=
kwargs
.
pop
(
"num_beam_groups"
,
1
)
self
.
diversity_penalty
=
kwargs
.
pop
(
"diversity_penalty"
,
0.0
)
self
.
temperature
=
kwargs
.
pop
(
"temperature"
,
1.0
)
self
.
top_k
=
kwargs
.
pop
(
"top_k"
,
50
)
self
.
top_p
=
kwargs
.
pop
(
"top_p"
,
1.0
)
self
.
typical_p
=
kwargs
.
pop
(
"typical_p"
,
1.0
)
self
.
repetition_penalty
=
kwargs
.
pop
(
"repetition_penalty"
,
1.0
)
self
.
length_penalty
=
kwargs
.
pop
(
"length_penalty"
,
1.0
)
self
.
no_repeat_ngram_size
=
kwargs
.
pop
(
"no_repeat_ngram_size"
,
0
)
self
.
encoder_no_repeat_ngram_size
=
kwargs
.
pop
(
"encoder_no_repeat_ngram_size"
,
0
)
self
.
bad_words_ids
=
kwargs
.
pop
(
"bad_words_ids"
,
None
)
self
.
num_return_sequences
=
kwargs
.
pop
(
"num_return_sequences"
,
1
)
self
.
chunk_size_feed_forward
=
kwargs
.
pop
(
"chunk_size_feed_forward"
,
0
)
self
.
output_scores
=
kwargs
.
pop
(
"output_scores"
,
False
)
self
.
return_dict_in_generate
=
kwargs
.
pop
(
"return_dict_in_generate"
,
False
)
self
.
forced_bos_token_id
=
kwargs
.
pop
(
"forced_bos_token_id"
,
None
)
self
.
forced_eos_token_id
=
kwargs
.
pop
(
"forced_eos_token_id"
,
None
)
self
.
remove_invalid_values
=
kwargs
.
pop
(
"remove_invalid_values"
,
False
)
self
.
exponential_decay_length_penalty
=
kwargs
.
pop
(
"exponential_decay_length_penalty"
,
None
)
self
.
suppress_tokens
=
kwargs
.
pop
(
"suppress_tokens"
,
None
)
self
.
begin_suppress_tokens
=
kwargs
.
pop
(
"begin_suppress_tokens"
,
None
)
# Fine-tuning task arguments
# Fine-tuning task arguments
self
.
architectures
=
kwargs
.
pop
(
"architectures"
,
None
)
self
.
architectures
=
kwargs
.
pop
(
"architectures"
,
None
)
...
@@ -463,6 +441,18 @@ class PretrainedConfig(PushToHubMixin):
...
@@ -463,6 +441,18 @@ class PretrainedConfig(PushToHubMixin):
if
os
.
path
.
isfile
(
save_directory
):
if
os
.
path
.
isfile
(
save_directory
):
raise
AssertionError
(
f
"Provided path (
{
save_directory
}
) should be a directory, not a file"
)
raise
AssertionError
(
f
"Provided path (
{
save_directory
}
) should be a directory, not a file"
)
non_default_generation_parameters
=
{}
for
parameter_name
,
default_value
in
self
.
_get_generation_defaults
().
items
():
if
hasattr
(
self
,
parameter_name
)
and
getattr
(
self
,
parameter_name
)
!=
default_value
:
non_default_generation_parameters
[
parameter_name
]
=
getattr
(
self
,
parameter_name
)
if
len
(
non_default_generation_parameters
)
>
0
:
logger
.
warning
(
"Some non-default generation parameters are set in the model config. These should go into a "
"GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) "
"instead. This warning will be raised to an exception in v4.41.
\n
"
f
"Non-default generation parameters:
{
str
(
non_default_generation_parameters
)
}
"
)
os
.
makedirs
(
save_directory
,
exist_ok
=
True
)
os
.
makedirs
(
save_directory
,
exist_ok
=
True
)
if
push_to_hub
:
if
push_to_hub
:
...
@@ -1050,6 +1040,45 @@ class PretrainedConfig(PushToHubMixin):
...
@@ -1050,6 +1040,45 @@ class PretrainedConfig(PushToHubMixin):
cls
.
_auto_class
=
auto_class
cls
.
_auto_class
=
auto_class
@
staticmethod
def
_get_generation_defaults
()
->
Dict
[
str
,
Any
]:
return
{
"max_length"
:
20
,
"min_length"
:
0
,
"do_sample"
:
False
,
"early_stopping"
:
False
,
"num_beams"
:
1
,
"num_beam_groups"
:
1
,
"diversity_penalty"
:
0.0
,
"temperature"
:
1.0
,
"top_k"
:
50
,
"top_p"
:
1.0
,
"typical_p"
:
1.0
,
"repetition_penalty"
:
1.0
,
"length_penalty"
:
1.0
,
"no_repeat_ngram_size"
:
0
,
"encoder_no_repeat_ngram_size"
:
0
,
"bad_words_ids"
:
None
,
"num_return_sequences"
:
1
,
"output_scores"
:
False
,
"return_dict_in_generate"
:
False
,
"forced_bos_token_id"
:
None
,
"forced_eos_token_id"
:
None
,
"remove_invalid_values"
:
False
,
"exponential_decay_length_penalty"
:
None
,
"suppress_tokens"
:
None
,
"begin_suppress_tokens"
:
None
,
}
def
_has_non_default_generation_parameters
(
self
)
->
bool
:
"""
Whether or not this instance holds non-default generation parameters.
"""
for
parameter_name
,
default_value
in
self
.
_get_generation_defaults
().
items
():
if
hasattr
(
self
,
parameter_name
)
and
getattr
(
self
,
parameter_name
)
!=
default_value
:
return
True
return
False
def
get_configuration_file
(
configuration_files
:
List
[
str
])
->
str
:
def
get_configuration_file
(
configuration_files
:
List
[
str
])
->
str
:
"""
"""
...
...
src/transformers/generation/utils.py
View file @
f4f57f9d
...
@@ -1274,11 +1274,14 @@ class GenerationMixin:
...
@@ -1274,11 +1274,14 @@ class GenerationMixin:
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if
generation_config
is
None
:
if
generation_config
is
None
:
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
# t
wo
conditions must be met
# t
hree
conditions must be met
# 1) the generation config must have been created from the model config (`_from_model_config` field);
# 1) the generation config must have been created from the model config (`_from_model_config` field);
# 2) the generation config must have seen no modification since its creation (the hash is the same).
# 2) the generation config must have seen no modification since its creation (the hash is the same);
if
self
.
generation_config
.
_from_model_config
and
self
.
generation_config
.
_original_object_hash
==
hash
(
# 3) the user must have set generation parameters in the model config.
self
.
generation_config
if
(
self
.
generation_config
.
_from_model_config
and
self
.
generation_config
.
_original_object_hash
==
hash
(
self
.
generation_config
)
and
self
.
config
.
_has_non_default_generation_parameters
()
):
):
new_generation_config
=
GenerationConfig
.
from_model_config
(
self
.
config
)
new_generation_config
=
GenerationConfig
.
from_model_config
(
self
.
config
)
if
new_generation_config
!=
self
.
generation_config
:
if
new_generation_config
!=
self
.
generation_config
:
...
...
src/transformers/modeling_utils.py
View file @
f4f57f9d
...
@@ -2335,6 +2335,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
...
@@ -2335,6 +2335,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if
not
_hf_peft_config_loaded
:
if
not
_hf_peft_config_loaded
:
model_to_save
.
config
.
save_pretrained
(
save_directory
)
model_to_save
.
config
.
save_pretrained
(
save_directory
)
if
self
.
can_generate
():
if
self
.
can_generate
():
# generation config built from the model config + the model config holds generation kwargs -> generate
# may revert to legacy behavior if the two don't match
if
(
model_to_save
.
generation_config
.
_from_model_config
and
model_to_save
.
config
.
_has_non_default_generation_parameters
()
):
new_generation_config
=
GenerationConfig
.
from_model_config
(
model_to_save
.
config
)
if
new_generation_config
!=
model_to_save
.
generation_config
:
logger
.
warning
(
"Your generation config was originally created from the model config, but the model "
"config has changed since then. Unless you pass the `generation_config` argument to this "
"model's `generate` calls, they will revert to the legacy behavior where the base "
"`generate` parameterization is loaded from the model config instead. "
"To avoid this behavior and this warning, we recommend you to overwrite the generation "
"config model attribute before calling the model's `save_pretrained`, preferably also "
"removing any generation kwargs from the model config. This warning will be raised to an "
"exception in v4.41."
)
model_to_save
.
generation_config
.
save_pretrained
(
save_directory
)
model_to_save
.
generation_config
.
save_pretrained
(
save_directory
)
if
_hf_peft_config_loaded
:
if
_hf_peft_config_loaded
:
...
...
tests/generation/test_framework_agnostic.py
View file @
f4f57f9d
...
@@ -529,7 +529,7 @@ class GenerationIntegrationTestsMixin:
...
@@ -529,7 +529,7 @@ class GenerationIntegrationTestsMixin:
pixel_values
=
floats_tensor
((
2
,
3
,
30
,
30
))
pixel_values
=
floats_tensor
((
2
,
3
,
30
,
30
))
model
=
model_cls
.
from_pretrained
(
"hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2"
)
model
=
model_cls
.
from_pretrained
(
"hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2"
)
model
.
config
.
decoder
.
eos_token_id
=
None
model
.
generation_config
.
eos_token_id
=
None
if
is_pt
:
if
is_pt
:
pixel_values
=
pixel_values
.
to
(
torch_device
)
pixel_values
=
pixel_values
.
to
(
torch_device
)
model
=
model
.
to
(
torch_device
)
model
=
model
.
to
(
torch_device
)
...
...
tests/test_configuration_utils.py
View file @
f4f57f9d
...
@@ -296,3 +296,19 @@ class ConfigTestUtils(unittest.TestCase):
...
@@ -296,3 +296,19 @@ class ConfigTestUtils(unittest.TestCase):
old_transformers
.
configuration_utils
.
__version__
=
"v3.0.0"
old_transformers
.
configuration_utils
.
__version__
=
"v3.0.0"
old_configuration
=
old_transformers
.
models
.
auto
.
AutoConfig
.
from_pretrained
(
repo
)
old_configuration
=
old_transformers
.
models
.
auto
.
AutoConfig
.
from_pretrained
(
repo
)
self
.
assertEqual
(
old_configuration
.
hidden_size
,
768
)
self
.
assertEqual
(
old_configuration
.
hidden_size
,
768
)
def
test_saving_config_with_custom_generation_kwargs_raises_warning
(
self
):
config
=
BertConfig
(
min_length
=
3
)
# `min_length = 3` is a non-default generation kwarg
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
with
self
.
assertLogs
(
"transformers.configuration_utils"
,
level
=
"WARNING"
)
as
logs
:
config
.
save_pretrained
(
tmp_dir
)
self
.
assertEqual
(
len
(
logs
.
output
),
1
)
self
.
assertIn
(
"min_length"
,
logs
.
output
[
0
])
def
test_has_non_default_generation_parameters
(
self
):
config
=
BertConfig
()
self
.
assertFalse
(
config
.
_has_non_default_generation_parameters
())
config
=
BertConfig
(
min_length
=
3
)
self
.
assertTrue
(
config
.
_has_non_default_generation_parameters
())
config
=
BertConfig
(
min_length
=
0
)
# `min_length = 0` is a default generation kwarg
self
.
assertFalse
(
config
.
_has_non_default_generation_parameters
())
tests/test_modeling_utils.py
View file @
f4f57f9d
...
@@ -1230,6 +1230,15 @@ class ModelUtilsTest(TestCasePlus):
...
@@ -1230,6 +1230,15 @@ class ModelUtilsTest(TestCasePlus):
for
p1
,
p2
in
zip
(
model
.
parameters
(),
new_model
.
parameters
()):
for
p1
,
p2
in
zip
(
model
.
parameters
(),
new_model
.
parameters
()):
self
.
assertTrue
(
torch
.
equal
(
p1
,
p2
))
self
.
assertTrue
(
torch
.
equal
(
p1
,
p2
))
def
test_modifying_model_config_causes_warning_saving_generation_config
(
self
):
model
=
AutoModelForCausalLM
.
from_pretrained
(
"gpt2"
)
model
.
config
.
top_k
=
1
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
with
self
.
assertLogs
(
"transformers.modeling_utils"
,
level
=
"WARNING"
)
as
logs
:
model
.
save_pretrained
(
tmp_dir
)
self
.
assertEqual
(
len
(
logs
.
output
),
1
)
self
.
assertIn
(
"Your generation config was originally created from the model config"
,
logs
.
output
[
0
])
@
slow
@
slow
@
require_torch
@
require_torch
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment