chenpangpang / transformers · Commits

Commit f4f57f9d (unverified), authored Jan 16, 2024 by Joao Gante, committed by GitHub on Jan 16, 2024

Config: warning when saving generation kwargs in the model config (#28514)
Parent commit: 7142bdfa

Showing 6 changed files with 107 additions and 32 deletions (+107 / -32):

  src/transformers/configuration_utils.py       +56 / -27
  src/transformers/generation/utils.py           +7 /  -4
  src/transformers/modeling_utils.py            +18 /  -0
  tests/generation/test_framework_agnostic.py    +1 /  -1
  tests/test_configuration_utils.py             +16 /  -0
  tests/test_modeling_utils.py                   +9 /  -0
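In short: after this commit, saving a `PretrainedConfig` that carries non-default generation parameters logs a warning steering users toward `GenerationConfig`. A minimal sketch of the new behavior, distilled from the tests added below (the directory name is illustrative):

    from transformers import BertConfig

    config = BertConfig(min_length=3)  # `min_length=3` is a generation kwarg with a non-default value
    config.save_pretrained("my-config-dir")  # now logs: "Some non-default generation parameters are set in the model config..."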
src/transformers/configuration_utils.py

@@ -277,6 +277,7 @@ class PretrainedConfig(PushToHubMixin):
         self.tie_word_embeddings = kwargs.pop(
             "tie_word_embeddings", True
         )  # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.
+        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)

         # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
         self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)

@@ -285,33 +286,10 @@ class PretrainedConfig(PushToHubMixin):
         self.add_cross_attention = kwargs.pop("add_cross_attention", False)
         self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)

-        # Parameters for sequence generation
-        self.max_length = kwargs.pop("max_length", 20)
-        self.min_length = kwargs.pop("min_length", 0)
-        self.do_sample = kwargs.pop("do_sample", False)
-        self.early_stopping = kwargs.pop("early_stopping", False)
-        self.num_beams = kwargs.pop("num_beams", 1)
-        self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
-        self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
-        self.temperature = kwargs.pop("temperature", 1.0)
-        self.top_k = kwargs.pop("top_k", 50)
-        self.top_p = kwargs.pop("top_p", 1.0)
-        self.typical_p = kwargs.pop("typical_p", 1.0)
-        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
-        self.length_penalty = kwargs.pop("length_penalty", 1.0)
-        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
-        self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
-        self.bad_words_ids = kwargs.pop("bad_words_ids", None)
-        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
-        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
-        self.output_scores = kwargs.pop("output_scores", False)
-        self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
-        self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
-        self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
-        self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
-        self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None)
-        self.suppress_tokens = kwargs.pop("suppress_tokens", None)
-        self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
+        # Retrocompatibility: Parameters for sequence generation. While we will keep the ability to load these
+        # parameters, saving them will be deprecated. In a distant future, we won't need to load them.
+        for parameter_name, default_value in self._get_generation_defaults().items():
+            setattr(self, parameter_name, kwargs.pop(parameter_name, default_value))

         # Fine-tuning task arguments
         self.architectures = kwargs.pop("architectures", None)

@@ -463,6 +441,18 @@ class PretrainedConfig(PushToHubMixin):
         if os.path.isfile(save_directory):
             raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

+        non_default_generation_parameters = {}
+        for parameter_name, default_value in self._get_generation_defaults().items():
+            if hasattr(self, parameter_name) and getattr(self, parameter_name) != default_value:
+                non_default_generation_parameters[parameter_name] = getattr(self, parameter_name)
+        if len(non_default_generation_parameters) > 0:
+            logger.warning(
+                "Some non-default generation parameters are set in the model config. These should go into a "
+                "GenerationConfig file (https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) "
+                "instead. This warning will be raised to an exception in v4.41.\n"
+                f"Non-default generation parameters: {str(non_default_generation_parameters)}"
+            )
+
         os.makedirs(save_directory, exist_ok=True)

         if push_to_hub:

@@ -1050,6 +1040,45 @@ class PretrainedConfig(PushToHubMixin):
         cls._auto_class = auto_class

+    @staticmethod
+    def _get_generation_defaults() -> Dict[str, Any]:
+        return {
+            "max_length": 20,
+            "min_length": 0,
+            "do_sample": False,
+            "early_stopping": False,
+            "num_beams": 1,
+            "num_beam_groups": 1,
+            "diversity_penalty": 0.0,
+            "temperature": 1.0,
+            "top_k": 50,
+            "top_p": 1.0,
+            "typical_p": 1.0,
+            "repetition_penalty": 1.0,
+            "length_penalty": 1.0,
+            "no_repeat_ngram_size": 0,
+            "encoder_no_repeat_ngram_size": 0,
+            "bad_words_ids": None,
+            "num_return_sequences": 1,
+            "output_scores": False,
+            "return_dict_in_generate": False,
+            "forced_bos_token_id": None,
+            "forced_eos_token_id": None,
+            "remove_invalid_values": False,
+            "exponential_decay_length_penalty": None,
+            "suppress_tokens": None,
+            "begin_suppress_tokens": None,
+        }
+
+    def _has_non_default_generation_parameters(self) -> bool:
+        """
+        Whether or not this instance holds non-default generation parameters.
+        """
+        for parameter_name, default_value in self._get_generation_defaults().items():
+            if hasattr(self, parameter_name) and getattr(self, parameter_name) != default_value:
+                return True
+        return False
+

 def get_configuration_file(configuration_files: List[str]) -> str:
     """
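The two helpers above centralize the generation defaults that were previously hard-coded as individual `kwargs.pop` calls. A standalone sketch of the pattern, using a toy class and an abridged three-entry defaults table (not the library code):

    _GENERATION_DEFAULTS = {"max_length": 20, "min_length": 0, "num_beams": 1}  # abridged

    class ToyConfig:
        def __init__(self, **kwargs):
            # load generation kwargs for retrocompatibility, falling back to defaults
            for name, default in _GENERATION_DEFAULTS.items():
                setattr(self, name, kwargs.pop(name, default))

        def _has_non_default_generation_parameters(self) -> bool:
            # True if any generation attribute deviates from its default
            return any(getattr(self, n) != d for n, d in _GENERATION_DEFAULTS.items())

    assert not ToyConfig()._has_non_default_generation_parameters()
    assert ToyConfig(min_length=3)._has_non_default_generation_parameters()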
src/transformers/generation/utils.py

@@ -1274,11 +1274,14 @@ class GenerationMixin:
         # priority: `generation_config` argument > `model.generation_config` (the default generation config)
         if generation_config is None:
             # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
-            # two conditions must be met
+            # three conditions must be met
             # 1) the generation config must have been created from the model config (`_from_model_config` field);
-            # 2) the generation config must have seen no modification since its creation (the hash is the same).
-            if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
-                self.generation_config
+            # 2) the generation config must have seen no modification since its creation (the hash is the same);
+            # 3) the user must have set generation parameters in the model config.
+            if (
+                self.generation_config._from_model_config
+                and self.generation_config._original_object_hash == hash(self.generation_config)
+                and self.config._has_non_default_generation_parameters()
             ):
                 new_generation_config = GenerationConfig.from_model_config(self.config)
                 if new_generation_config != self.generation_config:
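Condition 3) is the new one. A sketch of the tightened legacy gate, as standalone pseudocode mirroring the diff (not the library code itself):

    def uses_legacy_generation_config(model) -> bool:
        gen_config = model.generation_config
        return (
            gen_config._from_model_config                                  # 1) built from the model config
            and gen_config._original_object_hash == hash(gen_config)       # 2) unmodified since creation
            and model.config._has_non_default_generation_parameters()      # 3) model config holds generation kwargs
        )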
src/transformers/modeling_utils.py

@@ -2335,6 +2335,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if not _hf_peft_config_loaded:
             model_to_save.config.save_pretrained(save_directory)
         if self.can_generate():
+            # generation config built from the model config + the model config holds generation kwargs -> generate
+            # may revert to legacy behavior if the two don't match
+            if (
+                model_to_save.generation_config._from_model_config
+                and model_to_save.config._has_non_default_generation_parameters()
+            ):
+                new_generation_config = GenerationConfig.from_model_config(model_to_save.config)
+                if new_generation_config != model_to_save.generation_config:
+                    logger.warning(
+                        "Your generation config was originally created from the model config, but the model "
+                        "config has changed since then. Unless you pass the `generation_config` argument to this "
+                        "model's `generate` calls, they will revert to the legacy behavior where the base "
+                        "`generate` parameterization is loaded from the model config instead. "
+                        "To avoid this behavior and this warning, we recommend you to overwrite the generation "
+                        "config model attribute before calling the model's `save_pretrained`, preferably also "
+                        "removing any generation kwargs from the model config. This warning will be raised to an "
+                        "exception in v4.41."
+                    )
             model_to_save.generation_config.save_pretrained(save_directory)

         if _hf_peft_config_loaded:
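The warning text itself suggests the fix: keep generation kwargs out of the model config before saving. A sketch of the recommended pattern (checkpoint name taken from the test below; "out-dir" is illustrative):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.generation_config.top_k = 1  # set generation parameters on the generation config...
    model.save_pretrained("out-dir")   # ...and saving emits no warning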
tests/generation/test_framework_agnostic.py

@@ -529,7 +529,7 @@ class GenerationIntegrationTestsMixin:
         pixel_values = floats_tensor((2, 3, 30, 30))
         model = model_cls.from_pretrained("hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2")
-        model.config.decoder.eos_token_id = None
+        model.generation_config.eos_token_id = None
         if is_pt:
             pixel_values = pixel_values.to(torch_device)
             model = model.to(torch_device)
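Why this test changed: setting `eos_token_id` on the nested decoder config no longer reliably reaches `generate`, since the new third condition inspects the top-level model config; the test now disables EOS directly on the generation config instead.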
tests/test_configuration_utils.py

@@ -296,3 +296,19 @@ class ConfigTestUtils(unittest.TestCase):
         old_transformers.configuration_utils.__version__ = "v3.0.0"
         old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo)
         self.assertEqual(old_configuration.hidden_size, 768)
+
+    def test_saving_config_with_custom_generation_kwargs_raises_warning(self):
+        config = BertConfig(min_length=3)  # `min_length = 3` is a non-default generation kwarg
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self.assertLogs("transformers.configuration_utils", level="WARNING") as logs:
+                config.save_pretrained(tmp_dir)
+
+            self.assertEqual(len(logs.output), 1)
+            self.assertIn("min_length", logs.output[0])
+
+    def test_has_non_default_generation_parameters(self):
+        config = BertConfig()
+        self.assertFalse(config._has_non_default_generation_parameters())
+        config = BertConfig(min_length=3)
+        self.assertTrue(config._has_non_default_generation_parameters())
+        config = BertConfig(min_length=0)  # `min_length = 0` is a default generation kwarg
+        self.assertFalse(config._has_non_default_generation_parameters())
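As the new warning suggests, custom decoding parameters belong in a `GenerationConfig` saved alongside the model. A sketch of the replacement pattern (directory name illustrative):

    from transformers import BertConfig, GenerationConfig

    config = BertConfig()                                # model config: no generation kwargs
    generation_config = GenerationConfig(min_length=3)   # decoding parameters live here
    config.save_pretrained("my-model")                   # no warning
    generation_config.save_pretrained("my-model")        # writes generation_config.json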
tests/test_modeling_utils.py

@@ -1230,6 +1230,15 @@ class ModelUtilsTest(TestCasePlus):
         for p1, p2 in zip(model.parameters(), new_model.parameters()):
             self.assertTrue(torch.equal(p1, p2))

+    def test_modifying_model_config_causes_warning_saving_generation_config(self):
+        model = AutoModelForCausalLM.from_pretrained("gpt2")
+        model.config.top_k = 1
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self.assertLogs("transformers.modeling_utils", level="WARNING") as logs:
+                model.save_pretrained(tmp_dir)
+            self.assertEqual(len(logs.output), 1)
+            self.assertIn("Your generation config was originally created from the model config", logs.output[0])
+
     @slow
     @require_torch