chenpangpang / transformers / commits / 3319eb54

Commit 3319eb54 (unverified), authored Sep 12, 2023 by Joao Gante, committed via GitHub on Sep 12, 2023

Generate: legacy mode is only triggered when `generation_config` is untouched (#25962)
Parent: 18abc756
Showing 6 changed files with 53 additions and 29 deletions.
docs/source/en/generation_strategies.md              +1   -3
src/transformers/generation/configuration_utils.py   +24  -10
src/transformers/generation/flax_utils.py            +9   -5
src/transformers/generation/tf_utils.py              +9   -5
src/transformers/generation/utils.py                 +9   -5
tests/generation/test_utils.py                       +1   -1
docs/source/en/generation_strategies.md

````diff
@@ -55,12 +55,10 @@ When you load a model explicitly, you can inspect the generation configuration that comes with it:

 >>> from transformers import AutoModelForCausalLM

 >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
->>> model.generation_config  # doctest: +IGNORE_RESULT
+>>> model.generation_config
 GenerationConfig {
-    "_from_model_config": true,
     "bos_token_id": 50256,
     "eos_token_id": 50256,
-    "transformers_version": "4.26.0.dev0"
 }
 ```
````
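The user-visible effect of this doc change: metadata fields no longer show up when a config is printed or compared. A minimal sketch of the new behavior (model name taken from the example above; exact output depends on the installed version):

```python
from transformers import AutoModelForCausalLM, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("distilgpt2")

# After this commit, __repr__ serializes with ignore_metadata=True, so
# bookkeeping fields such as "_from_model_config" and "transformers_version"
# no longer appear in the printed config.
print(model.generation_config)

# Equality ignores the same metadata, so a freshly constructed config with the
# same user-visible values should compare equal (assuming distilgpt2 sets no
# other non-default generation values).
fresh = GenerationConfig(bos_token_id=50256, eos_token_id=50256)
print(fresh == model.generation_config)  # expected: True
```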
src/transformers/generation/configuration_utils.py

```diff
@@ -34,6 +34,7 @@ from ..utils import (

 logger = logging.get_logger(__name__)

+METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")


 class GenerationConfig(PushToHubMixin):
@@ -315,20 +316,19 @@ class GenerationConfig(PushToHubMixin):
         # Validate the values of the attributes
         self.validate(is_init=True)

+    def __hash__(self):
+        return hash(self.to_json_string(ignore_metadata=True))
+
     def __eq__(self, other):
         if not isinstance(other, GenerationConfig):
             return False
-        self_dict = self.__dict__.copy()
-        other_dict = other.__dict__.copy()
-        # ignore metadata
-        for metadata_field in ("_from_model_config", "_commit_hash", "transformers_version"):
-            self_dict.pop(metadata_field, None)
-            other_dict.pop(metadata_field, None)
-        return self_dict == other_dict
+        self_without_metadata = self.to_json_string(use_diff=False, ignore_metadata=True)
+        other_without_metadata = other.to_json_string(use_diff=False, ignore_metadata=True)
+        return self_without_metadata == other_without_metadata

     def __repr__(self):
-        return f"{self.__class__.__name__} {self.to_json_string()}"
+        return f"{self.__class__.__name__} {self.to_json_string(ignore_metadata=True)}"

     def validate(self, is_init=False):
         """
@@ -729,7 +729,9 @@ class GenerationConfig(PushToHubMixin):
         else:
             logger.info(
                 f"loading configuration file {configuration_file} from cache at {resolved_config_file}"
             )

-        return cls.from_dict(config_dict, **kwargs)
+        config = cls.from_dict(config_dict, **kwargs)
+        config._original_object_hash = hash(config)  # Hash to detect whether the instance was modified
+        return config

     @classmethod
     def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
@@ -814,8 +816,12 @@ class GenerationConfig(PushToHubMixin):
             `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
         """
         output = copy.deepcopy(self.__dict__)
+
+        # Fields to ignore at serialization time
         if "_commit_hash" in output:
             del output["_commit_hash"]
+        if "_original_object_hash" in output:
+            del output["_original_object_hash"]

         # Transformers version when serializing this file
         output["transformers_version"] = __version__
@@ -823,7 +829,7 @@ class GenerationConfig(PushToHubMixin):
         self.dict_torch_dtype_to_str(output)

         return output

-    def to_json_string(self, use_diff: bool = True) -> str:
+    def to_json_string(self, use_diff: bool = True, ignore_metadata: bool = False) -> str:
         """
         Serializes this instance to a JSON string.
@@ -831,6 +837,8 @@ class GenerationConfig(PushToHubMixin):
             use_diff (`bool`, *optional*, defaults to `True`):
                 If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
                 is serialized to JSON string.
+            ignore_metadata (`bool`, *optional*, defaults to `False`):
+                Whether to ignore the metadata fields present in the instance

         Returns:
             `str`: String containing all the attributes that make up this configuration instance in JSON format.
@@ -839,6 +847,11 @@ class GenerationConfig(PushToHubMixin):
             config_dict = self.to_diff_dict()
         else:
             config_dict = self.to_dict()
+
+        if ignore_metadata:
+            for metadata_field in METADATA_FIELDS:
+                config_dict.pop(metadata_field, None)
+
         return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

     def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
@@ -882,6 +895,7 @@ class GenerationConfig(PushToHubMixin):
                 if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
                     setattr(config, attr, decoder_config[attr])

+        config._original_object_hash = hash(config)  # Hash to detect whether the instance was modified
         return config

     def update(self, **kwargs):
```
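The core mechanism added to `configuration_utils.py` is modification detection: hash a metadata-free serialization once at creation time (`_original_object_hash`), then compare against a fresh hash later. A self-contained sketch of that pattern, with hypothetical names (`TrackedConfig`, `is_untouched`) rather than the actual transformers class:

```python
import json


class TrackedConfig:
    """Minimal illustration of hash-based modification detection: hash a
    metadata-free JSON dump, remember it at creation, compare later."""

    METADATA_FIELDS = ("_original_hash",)  # bookkeeping fields excluded from the hash

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
        self._original_hash = hash(self)  # snapshot taken once, at creation

    def __hash__(self):
        # Serialize everything except metadata, so bookkeeping fields
        # cannot change the hash.
        state = {k: v for k, v in self.__dict__.items() if k not in self.METADATA_FIELDS}
        return hash(json.dumps(state, sort_keys=True))

    def is_untouched(self) -> bool:
        return self._original_hash == hash(self)


cfg = TrackedConfig(max_length=20)
print(cfg.is_untouched())  # True
cfg.max_length = 10        # any attribute change alters the hash
print(cfg.is_untouched())  # False
```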
src/transformers/generation/flax_utils.py

```diff
@@ -310,16 +310,20 @@ class FlaxGenerationMixin:
         # priority: `generation_config` argument > `model.generation_config` (the default generation config)
         if generation_config is None:
-            # legacy: users may modify the model configuration to control generation -- update the generation config
-            # model attribute accordingly, if it was created from the model config
-            if self.generation_config._from_model_config:
+            # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
+            # two conditions must be met
+            # 1) the generation config must have been created from the model config (`_from_model_config` field);
+            # 2) the generation config must have seen no modification since its creation (the hash is the same).
+            if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
+                self.generation_config
+            ):
                 new_generation_config = GenerationConfig.from_model_config(self.config)
                 if new_generation_config != self.generation_config:
                     warnings.warn(
                         "You have modified the pretrained model configuration to control generation. This is a"
                         " deprecated strategy to control generation and will be removed soon, in a future version."
-                        " Please use a generation configuration file (see"
-                        " https://huggingface.co/docs/transformers/main_classes/text_generation )"
+                        " Please use and modify the model generation configuration (see"
+                        " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
                     )
                 self.generation_config = new_generation_config

             generation_config = self.generation_config
```
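The new condition reads as "legacy mode only if the config came from the model config *and* has not been touched since". A hypothetical helper (not part of transformers) that restates it; the same check is mirrored verbatim in the TF and PyTorch mixins below:

```python
def uses_legacy_generation(model) -> bool:
    # Hypothetical restatement of the condition added above: both parts must
    # hold, so any user modification to generation_config (which changes its
    # hash) opts the model out of the legacy path.
    gen_config = model.generation_config
    return gen_config._from_model_config and gen_config._original_object_hash == hash(gen_config)
```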
src/transformers/generation/tf_utils.py

```diff
@@ -716,16 +716,20 @@ class TFGenerationMixin:
         # priority: `generation_config` argument > `model.generation_config` (the default generation config)
         if generation_config is None:
-            # legacy: users may modify the model configuration to control generation -- update the generation config
-            # model attribute accordingly, if it was created from the model config
-            if self.generation_config._from_model_config:
+            # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
+            # two conditions must be met
+            # 1) the generation config must have been created from the model config (`_from_model_config` field);
+            # 2) the generation config must have seen no modification since its creation (the hash is the same).
+            if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
+                self.generation_config
+            ):
                 new_generation_config = GenerationConfig.from_model_config(self.config)
                 if new_generation_config != self.generation_config:
                     warnings.warn(
                         "You have modified the pretrained model configuration to control generation. This is a"
                         " deprecated strategy to control generation and will be removed soon, in a future version."
-                        " Please use a generation configuration file (see"
-                        " https://huggingface.co/docs/transformers/main_classes/text_generation )"
+                        " Please use and modify the model generation configuration (see"
+                        " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
                     )
                 self.generation_config = new_generation_config

             generation_config = self.generation_config
```
src/transformers/generation/utils.py

```diff
@@ -1409,16 +1409,20 @@ class GenerationMixin:
         # priority: `generation_config` argument > `model.generation_config` (the default generation config)
         if generation_config is None:
-            # legacy: users may modify the model configuration to control generation -- update the generation config
-            # model attribute accordingly, if it was created from the model config
-            if self.generation_config._from_model_config:
+            # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
+            # two conditions must be met
+            # 1) the generation config must have been created from the model config (`_from_model_config` field);
+            # 2) the generation config must have seen no modification since its creation (the hash is the same).
+            if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
+                self.generation_config
+            ):
                 new_generation_config = GenerationConfig.from_model_config(self.config)
                 if new_generation_config != self.generation_config:
                     warnings.warn(
                         "You have modified the pretrained model configuration to control generation. This is a"
                         " deprecated strategy to control generation and will be removed soon, in a future version."
-                        " Please use a generation configuration file (see"
-                        " https://huggingface.co/docs/transformers/main_classes/text_generation )"
+                        " Please use and modify the model generation configuration (see"
+                        " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
                    )
                 self.generation_config = new_generation_config

             generation_config = self.generation_config
```
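End to end, the PyTorch change means edits to `model.config` keep steering generation only while `model.generation_config` stays untouched. A usage sketch, assuming `distilgpt2` as in the docs example:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
inputs = tokenizer("Hello", return_tensors="pt")

# Legacy path: generation_config is untouched, so editing model.config still
# steers generation (and emits the deprecation warning shown above).
model.config.max_length = 30
model.generate(**inputs)

# Once generation_config is modified, its hash diverges from
# _original_object_hash, legacy mode is skipped, and generation_config wins.
model.generation_config.max_length = 10
model.generate(**inputs)  # no warning; max_length now comes from generation_config
```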
tests/generation/test_utils.py

```diff
@@ -2880,7 +2880,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
         # Generation config max_length != 20 -> no warning
         with warnings.catch_warnings(record=True) as warning_list:
+            # generation_config is modified -> legacy mode is disabled = generation_config takes precedence
             model.generation_config.max_length = 10
-            model.generation_config._from_model_config = False  # otherwise model.config.max_length=20 takes precedence
             model.generate(input_ids)
             self.assertEqual(len(warning_list), 0)
```
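For contrast, a hypothetical companion check in the same style (not part of this commit), covering the case where legacy mode stays active: with `generation_config` untouched, a diverging `model.config` value should fire the deprecation warning. It reuses the surrounding test's `model` and `input_ids`:

```python
with warnings.catch_warnings(record=True) as warning_list:
    model.config.max_length = 15  # differs from the untouched generation config
    model.generate(input_ids)
# assumes generate() emits no other warnings in this configuration
self.assertEqual(len(warning_list), 1)
```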