Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3319eb54
Unverified
Commit
3319eb54
authored
Sep 12, 2023
by
Joao Gante
Committed by
GitHub
Sep 12, 2023
Browse files
Generate: legacy mode is only triggered when `generation_config` is untouched (#25962)
parent
18abc756
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
53 additions
and
29 deletions
+53
-29
docs/source/en/generation_strategies.md
docs/source/en/generation_strategies.md
+1
-3
src/transformers/generation/configuration_utils.py
src/transformers/generation/configuration_utils.py
+24
-10
src/transformers/generation/flax_utils.py
src/transformers/generation/flax_utils.py
+9
-5
src/transformers/generation/tf_utils.py
src/transformers/generation/tf_utils.py
+9
-5
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+9
-5
tests/generation/test_utils.py
tests/generation/test_utils.py
+1
-1
No files found.
docs/source/en/generation_strategies.md
View file @
3319eb54
...
...
@@ -55,12 +55,10 @@ When you load a model explicitly, you can inspect the generation configuration t
>>>
from
transformers
import
AutoModelForCausalLM
>>>
model
=
AutoModelForCausalLM
.
from_pretrained
(
"distilgpt2"
)
>>>
model
.
generation_config
# doctest: +IGNORE_RESULT
>>>
model
.
generation_config
GenerationConfig
{
"_from_model_config"
:
true
,
"bos_token_id"
:
50256
,
"eos_token_id"
:
50256
,
"transformers_version"
:
"4.26.0.dev0"
}
```
...
...
src/transformers/generation/configuration_utils.py
View file @
3319eb54
...
...
@@ -34,6 +34,7 @@ from ..utils import (
logger
=
logging
.
get_logger
(
__name__
)
METADATA_FIELDS
=
(
"_from_model_config"
,
"_commit_hash"
,
"_original_object_hash"
,
"transformers_version"
)
class
GenerationConfig
(
PushToHubMixin
):
...
...
@@ -315,20 +316,19 @@ class GenerationConfig(PushToHubMixin):
# Validate the values of the attributes
self
.
validate
(
is_init
=
True
)
def
__hash__
(
self
):
return
hash
(
self
.
to_json_string
(
ignore_metadata
=
True
))
def
__eq__
(
self
,
other
):
if
not
isinstance
(
other
,
GenerationConfig
):
return
False
self_dict
=
self
.
__dict__
.
copy
()
other_dict
=
other
.
__dict__
.
copy
()
# ignore metadata
for
metadata_field
in
(
"_from_model_config"
,
"_commit_hash"
,
"transformers_version"
):
self_dict
.
pop
(
metadata_field
,
None
)
other_dict
.
pop
(
metadata_field
,
None
)
return
self_dict
==
other_dict
self_without_metadata
=
self
.
to_json_string
(
use_diff
=
False
,
ignore_metadata
=
True
)
other_without_metadata
=
other
.
to_json_string
(
use_diff
=
False
,
ignore_metadata
=
True
)
return
self_without_metadata
==
other_without_metadata
def
__repr__
(
self
):
return
f
"
{
self
.
__class__
.
__name__
}
{
self
.
to_json_string
()
}
"
return
f
"
{
self
.
__class__
.
__name__
}
{
self
.
to_json_string
(
ignore_metadata
=
True
)
}
"
def
validate
(
self
,
is_init
=
False
):
"""
...
...
@@ -729,7 +729,9 @@ class GenerationConfig(PushToHubMixin):
else
:
logger
.
info
(
f
"loading configuration file
{
configuration_file
}
from cache at
{
resolved_config_file
}
"
)
return
cls
.
from_dict
(
config_dict
,
**
kwargs
)
config
=
cls
.
from_dict
(
config_dict
,
**
kwargs
)
config
.
_original_object_hash
=
hash
(
config
)
# Hash to detect whether the instance was modified
return
config
@
classmethod
def
_dict_from_json_file
(
cls
,
json_file
:
Union
[
str
,
os
.
PathLike
]):
...
...
@@ -814,8 +816,12 @@ class GenerationConfig(PushToHubMixin):
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output
=
copy
.
deepcopy
(
self
.
__dict__
)
# Fields to ignore at serialization time
if
"_commit_hash"
in
output
:
del
output
[
"_commit_hash"
]
if
"_original_object_hash"
in
output
:
del
output
[
"_original_object_hash"
]
# Transformers version when serializing this file
output
[
"transformers_version"
]
=
__version__
...
...
@@ -823,7 +829,7 @@ class GenerationConfig(PushToHubMixin):
self
.
dict_torch_dtype_to_str
(
output
)
return
output
def
to_json_string
(
self
,
use_diff
:
bool
=
True
)
->
str
:
def
to_json_string
(
self
,
use_diff
:
bool
=
True
,
ignore_metadata
:
bool
=
False
)
->
str
:
"""
Serializes this instance to a JSON string.
...
...
@@ -831,6 +837,8 @@ class GenerationConfig(PushToHubMixin):
use_diff (`bool`, *optional*, defaults to `True`):
If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
is serialized to JSON string.
ignore_metadata (`bool`, *optional*, defaults to `False`):
Whether to ignore the metadata fields present in the instance
Returns:
`str`: String containing all the attributes that make up this configuration instance in JSON format.
...
...
@@ -839,6 +847,11 @@ class GenerationConfig(PushToHubMixin):
config_dict
=
self
.
to_diff_dict
()
else
:
config_dict
=
self
.
to_dict
()
if
ignore_metadata
:
for
metadata_field
in
METADATA_FIELDS
:
config_dict
.
pop
(
metadata_field
,
None
)
return
json
.
dumps
(
config_dict
,
indent
=
2
,
sort_keys
=
True
)
+
"
\n
"
def
to_json_file
(
self
,
json_file_path
:
Union
[
str
,
os
.
PathLike
],
use_diff
:
bool
=
True
):
...
...
@@ -882,6 +895,7 @@ class GenerationConfig(PushToHubMixin):
if
attr
in
decoder_config
and
getattr
(
config
,
attr
)
==
getattr
(
default_generation_config
,
attr
):
setattr
(
config
,
attr
,
decoder_config
[
attr
])
config
.
_original_object_hash
=
hash
(
config
)
# Hash to detect whether the instance was modified
return
config
def
update
(
self
,
**
kwargs
):
...
...
src/transformers/generation/flax_utils.py
View file @
3319eb54
...
...
@@ -310,16 +310,20 @@ class FlaxGenerationMixin:
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if
generation_config
is
None
:
# legacy: users may modify the model configuration to control generation -- update the generation config
# model attribute accordingly, if it was created from the model config
if
self
.
generation_config
.
_from_model_config
:
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
# two conditions must be met
# 1) the generation config must have been created from the model config (`_from_model_config` field);
# 2) the generation config must have seen no modification since its creation (the hash is the same).
if
self
.
generation_config
.
_from_model_config
and
self
.
generation_config
.
_original_object_hash
==
hash
(
self
.
generation_config
):
new_generation_config
=
GenerationConfig
.
from_model_config
(
self
.
config
)
if
new_generation_config
!=
self
.
generation_config
:
warnings
.
warn
(
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" Please use a generation configuration
file
(see"
" https://huggingface.co/docs/transformers/
main_classes/
text
_
generation )"
" Please use a
nd modify the model
generation configuration (see"
" https://huggingface.co/docs/transformers/
generation_strategies#default-
text
-
generation
-configuration
)"
)
self
.
generation_config
=
new_generation_config
generation_config
=
self
.
generation_config
...
...
src/transformers/generation/tf_utils.py
View file @
3319eb54
...
...
@@ -716,16 +716,20 @@ class TFGenerationMixin:
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if
generation_config
is
None
:
# legacy: users may modify the model configuration to control generation -- update the generation config
# model attribute accordingly, if it was created from the model config
if
self
.
generation_config
.
_from_model_config
:
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
# two conditions must be met
# 1) the generation config must have been created from the model config (`_from_model_config` field);
# 2) the generation config must have seen no modification since its creation (the hash is the same).
if
self
.
generation_config
.
_from_model_config
and
self
.
generation_config
.
_original_object_hash
==
hash
(
self
.
generation_config
):
new_generation_config
=
GenerationConfig
.
from_model_config
(
self
.
config
)
if
new_generation_config
!=
self
.
generation_config
:
warnings
.
warn
(
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" Please use a generation configuration
file
(see"
" https://huggingface.co/docs/transformers/
main_classes/
text
_
generation )"
" Please use a
nd modify the model
generation configuration (see"
" https://huggingface.co/docs/transformers/
generation_strategies#default-
text
-
generation
-configuration
)"
)
self
.
generation_config
=
new_generation_config
generation_config
=
self
.
generation_config
...
...
src/transformers/generation/utils.py
View file @
3319eb54
...
...
@@ -1409,16 +1409,20 @@ class GenerationMixin:
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if
generation_config
is
None
:
# legacy: users may modify the model configuration to control generation -- update the generation config
# model attribute accordingly, if it was created from the model config
if
self
.
generation_config
.
_from_model_config
:
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
# two conditions must be met
# 1) the generation config must have been created from the model config (`_from_model_config` field);
# 2) the generation config must have seen no modification since its creation (the hash is the same).
if
self
.
generation_config
.
_from_model_config
and
self
.
generation_config
.
_original_object_hash
==
hash
(
self
.
generation_config
):
new_generation_config
=
GenerationConfig
.
from_model_config
(
self
.
config
)
if
new_generation_config
!=
self
.
generation_config
:
warnings
.
warn
(
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" Please use a generation configuration
file
(see"
" https://huggingface.co/docs/transformers/
main_classes/
text
_
generation )"
" Please use a
nd modify the model
generation configuration (see"
" https://huggingface.co/docs/transformers/
generation_strategies#default-
text
-
generation
-configuration
)"
)
self
.
generation_config
=
new_generation_config
generation_config
=
self
.
generation_config
...
...
tests/generation/test_utils.py
View file @
3319eb54
...
...
@@ -2880,7 +2880,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
# Generation config max_length != 20 -> no warning
with
warnings
.
catch_warnings
(
record
=
True
)
as
warning_list
:
# generation_config is modified -> legacy mode is disabled = generation_config takes precedence
model
.
generation_config
.
max_length
=
10
model
.
generation_config
.
_from_model_config
=
False
# otherwise model.config.max_length=20 takes precedence
model
.
generate
(
input_ids
)
self
.
assertEqual
(
len
(
warning_list
),
0
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment