Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3319eb54
Unverified
Commit
3319eb54
authored
Sep 12, 2023
by
Joao Gante
Committed by
GitHub
Sep 12, 2023
Browse files
Generate: legacy mode is only triggered when `generation_config` is untouched (#25962)
parent
18abc756
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
53 additions
and
29 deletions
+53
-29
docs/source/en/generation_strategies.md
docs/source/en/generation_strategies.md
+1
-3
src/transformers/generation/configuration_utils.py
src/transformers/generation/configuration_utils.py
+24
-10
src/transformers/generation/flax_utils.py
src/transformers/generation/flax_utils.py
+9
-5
src/transformers/generation/tf_utils.py
src/transformers/generation/tf_utils.py
+9
-5
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+9
-5
tests/generation/test_utils.py
tests/generation/test_utils.py
+1
-1
No files found.
docs/source/en/generation_strategies.md
View file @
3319eb54
...
@@ -55,12 +55,10 @@ When you load a model explicitly, you can inspect the generation configuration t
...
@@ -55,12 +55,10 @@ When you load a model explicitly, you can inspect the generation configuration t
>>>
from
transformers
import
AutoModelForCausalLM
>>>
from
transformers
import
AutoModelForCausalLM
>>>
model
=
AutoModelForCausalLM
.
from_pretrained
(
"distilgpt2"
)
>>>
model
=
AutoModelForCausalLM
.
from_pretrained
(
"distilgpt2"
)
>>>
model
.
generation_config
# doctest: +IGNORE_RESULT
>>>
model
.
generation_config
GenerationConfig
{
GenerationConfig
{
"_from_model_config"
:
true
,
"bos_token_id"
:
50256
,
"bos_token_id"
:
50256
,
"eos_token_id"
:
50256
,
"eos_token_id"
:
50256
,
"transformers_version"
:
"4.26.0.dev0"
}
}
```
```
...
...
src/transformers/generation/configuration_utils.py
View file @
3319eb54
...
@@ -34,6 +34,7 @@ from ..utils import (
...
@@ -34,6 +34,7 @@ from ..utils import (
logger
=
logging
.
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
METADATA_FIELDS
=
(
"_from_model_config"
,
"_commit_hash"
,
"_original_object_hash"
,
"transformers_version"
)
class
GenerationConfig
(
PushToHubMixin
):
class
GenerationConfig
(
PushToHubMixin
):
...
@@ -315,20 +316,19 @@ class GenerationConfig(PushToHubMixin):
...
@@ -315,20 +316,19 @@ class GenerationConfig(PushToHubMixin):
# Validate the values of the attributes
# Validate the values of the attributes
self
.
validate
(
is_init
=
True
)
self
.
validate
(
is_init
=
True
)
def
__hash__
(
self
):
return
hash
(
self
.
to_json_string
(
ignore_metadata
=
True
))
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
if
not
isinstance
(
other
,
GenerationConfig
):
if
not
isinstance
(
other
,
GenerationConfig
):
return
False
return
False
self_dict
=
self
.
__dict__
.
copy
()
self_without_metadata
=
self
.
to_json_string
(
use_diff
=
False
,
ignore_metadata
=
True
)
other_dict
=
other
.
__dict__
.
copy
()
other_without_metadata
=
other
.
to_json_string
(
use_diff
=
False
,
ignore_metadata
=
True
)
# ignore metadata
return
self_without_metadata
==
other_without_metadata
for
metadata_field
in
(
"_from_model_config"
,
"_commit_hash"
,
"transformers_version"
):
self_dict
.
pop
(
metadata_field
,
None
)
other_dict
.
pop
(
metadata_field
,
None
)
return
self_dict
==
other_dict
def
__repr__
(
self
):
def
__repr__
(
self
):
return
f
"
{
self
.
__class__
.
__name__
}
{
self
.
to_json_string
()
}
"
return
f
"
{
self
.
__class__
.
__name__
}
{
self
.
to_json_string
(
ignore_metadata
=
True
)
}
"
def
validate
(
self
,
is_init
=
False
):
def
validate
(
self
,
is_init
=
False
):
"""
"""
...
@@ -729,7 +729,9 @@ class GenerationConfig(PushToHubMixin):
...
@@ -729,7 +729,9 @@ class GenerationConfig(PushToHubMixin):
else
:
else
:
logger
.
info
(
f
"loading configuration file
{
configuration_file
}
from cache at
{
resolved_config_file
}
"
)
logger
.
info
(
f
"loading configuration file
{
configuration_file
}
from cache at
{
resolved_config_file
}
"
)
return
cls
.
from_dict
(
config_dict
,
**
kwargs
)
config
=
cls
.
from_dict
(
config_dict
,
**
kwargs
)
config
.
_original_object_hash
=
hash
(
config
)
# Hash to detect whether the instance was modified
return
config
@
classmethod
@
classmethod
def
_dict_from_json_file
(
cls
,
json_file
:
Union
[
str
,
os
.
PathLike
]):
def
_dict_from_json_file
(
cls
,
json_file
:
Union
[
str
,
os
.
PathLike
]):
...
@@ -814,8 +816,12 @@ class GenerationConfig(PushToHubMixin):
...
@@ -814,8 +816,12 @@ class GenerationConfig(PushToHubMixin):
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
"""
output
=
copy
.
deepcopy
(
self
.
__dict__
)
output
=
copy
.
deepcopy
(
self
.
__dict__
)
# Fields to ignore at serialization time
if
"_commit_hash"
in
output
:
if
"_commit_hash"
in
output
:
del
output
[
"_commit_hash"
]
del
output
[
"_commit_hash"
]
if
"_original_object_hash"
in
output
:
del
output
[
"_original_object_hash"
]
# Transformers version when serializing this file
# Transformers version when serializing this file
output
[
"transformers_version"
]
=
__version__
output
[
"transformers_version"
]
=
__version__
...
@@ -823,7 +829,7 @@ class GenerationConfig(PushToHubMixin):
...
@@ -823,7 +829,7 @@ class GenerationConfig(PushToHubMixin):
self
.
dict_torch_dtype_to_str
(
output
)
self
.
dict_torch_dtype_to_str
(
output
)
return
output
return
output
def
to_json_string
(
self
,
use_diff
:
bool
=
True
)
->
str
:
def
to_json_string
(
self
,
use_diff
:
bool
=
True
,
ignore_metadata
:
bool
=
False
)
->
str
:
"""
"""
Serializes this instance to a JSON string.
Serializes this instance to a JSON string.
...
@@ -831,6 +837,8 @@ class GenerationConfig(PushToHubMixin):
...
@@ -831,6 +837,8 @@ class GenerationConfig(PushToHubMixin):
use_diff (`bool`, *optional*, defaults to `True`):
use_diff (`bool`, *optional*, defaults to `True`):
If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
is serialized to JSON string.
is serialized to JSON string.
ignore_metadata (`bool`, *optional*, defaults to `False`):
Whether to ignore the metadata fields present in the instance
Returns:
Returns:
`str`: String containing all the attributes that make up this configuration instance in JSON format.
`str`: String containing all the attributes that make up this configuration instance in JSON format.
...
@@ -839,6 +847,11 @@ class GenerationConfig(PushToHubMixin):
...
@@ -839,6 +847,11 @@ class GenerationConfig(PushToHubMixin):
config_dict
=
self
.
to_diff_dict
()
config_dict
=
self
.
to_diff_dict
()
else
:
else
:
config_dict
=
self
.
to_dict
()
config_dict
=
self
.
to_dict
()
if
ignore_metadata
:
for
metadata_field
in
METADATA_FIELDS
:
config_dict
.
pop
(
metadata_field
,
None
)
return
json
.
dumps
(
config_dict
,
indent
=
2
,
sort_keys
=
True
)
+
"
\n
"
return
json
.
dumps
(
config_dict
,
indent
=
2
,
sort_keys
=
True
)
+
"
\n
"
def
to_json_file
(
self
,
json_file_path
:
Union
[
str
,
os
.
PathLike
],
use_diff
:
bool
=
True
):
def
to_json_file
(
self
,
json_file_path
:
Union
[
str
,
os
.
PathLike
],
use_diff
:
bool
=
True
):
...
@@ -882,6 +895,7 @@ class GenerationConfig(PushToHubMixin):
...
@@ -882,6 +895,7 @@ class GenerationConfig(PushToHubMixin):
if
attr
in
decoder_config
and
getattr
(
config
,
attr
)
==
getattr
(
default_generation_config
,
attr
):
if
attr
in
decoder_config
and
getattr
(
config
,
attr
)
==
getattr
(
default_generation_config
,
attr
):
setattr
(
config
,
attr
,
decoder_config
[
attr
])
setattr
(
config
,
attr
,
decoder_config
[
attr
])
config
.
_original_object_hash
=
hash
(
config
)
# Hash to detect whether the instance was modified
return
config
return
config
def
update
(
self
,
**
kwargs
):
def
update
(
self
,
**
kwargs
):
...
...
src/transformers/generation/flax_utils.py
View file @
3319eb54
...
@@ -310,16 +310,20 @@ class FlaxGenerationMixin:
...
@@ -310,16 +310,20 @@ class FlaxGenerationMixin:
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if
generation_config
is
None
:
if
generation_config
is
None
:
# legacy: users may modify the model configuration to control generation -- update the generation config
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
# model attribute accordingly, if it was created from the model config
# two conditions must be met
if
self
.
generation_config
.
_from_model_config
:
# 1) the generation config must have been created from the model config (`_from_model_config` field);
# 2) the generation config must have seen no modification since its creation (the hash is the same).
if
self
.
generation_config
.
_from_model_config
and
self
.
generation_config
.
_original_object_hash
==
hash
(
self
.
generation_config
):
new_generation_config
=
GenerationConfig
.
from_model_config
(
self
.
config
)
new_generation_config
=
GenerationConfig
.
from_model_config
(
self
.
config
)
if
new_generation_config
!=
self
.
generation_config
:
if
new_generation_config
!=
self
.
generation_config
:
warnings
.
warn
(
warnings
.
warn
(
"You have modified the pretrained model configuration to control generation. This is a"
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" deprecated strategy to control generation and will be removed soon, in a future version."
" Please use a generation configuration file (see"
" Please use and modify the model generation configuration (see"
" https://huggingface.co/docs/transformers/main_classes/text_generation )"
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
)
)
self
.
generation_config
=
new_generation_config
self
.
generation_config
=
new_generation_config
generation_config
=
self
.
generation_config
generation_config
=
self
.
generation_config
...
...
src/transformers/generation/tf_utils.py
View file @
3319eb54
...
@@ -716,16 +716,20 @@ class TFGenerationMixin:
...
@@ -716,16 +716,20 @@ class TFGenerationMixin:
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if
generation_config
is
None
:
if
generation_config
is
None
:
# legacy: users may modify the model configuration to control generation -- update the generation config
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
# model attribute accordingly, if it was created from the model config
# two conditions must be met
if
self
.
generation_config
.
_from_model_config
:
# 1) the generation config must have been created from the model config (`_from_model_config` field);
# 2) the generation config must have seen no modification since its creation (the hash is the same).
if
self
.
generation_config
.
_from_model_config
and
self
.
generation_config
.
_original_object_hash
==
hash
(
self
.
generation_config
):
new_generation_config
=
GenerationConfig
.
from_model_config
(
self
.
config
)
new_generation_config
=
GenerationConfig
.
from_model_config
(
self
.
config
)
if
new_generation_config
!=
self
.
generation_config
:
if
new_generation_config
!=
self
.
generation_config
:
warnings
.
warn
(
warnings
.
warn
(
"You have modified the pretrained model configuration to control generation. This is a"
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" deprecated strategy to control generation and will be removed soon, in a future version."
" Please use a generation configuration file (see"
" Please use and modify the model generation configuration (see"
" https://huggingface.co/docs/transformers/main_classes/text_generation )"
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
)
)
self
.
generation_config
=
new_generation_config
self
.
generation_config
=
new_generation_config
generation_config
=
self
.
generation_config
generation_config
=
self
.
generation_config
...
...
src/transformers/generation/utils.py
View file @
3319eb54
...
@@ -1409,16 +1409,20 @@ class GenerationMixin:
...
@@ -1409,16 +1409,20 @@ class GenerationMixin:
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if
generation_config
is
None
:
if
generation_config
is
None
:
# legacy: users may modify the model configuration to control generation -- update the generation config
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
# model attribute accordingly, if it was created from the model config
# two conditions must be met
if
self
.
generation_config
.
_from_model_config
:
# 1) the generation config must have been created from the model config (`_from_model_config` field);
# 2) the generation config must have seen no modification since its creation (the hash is the same).
if
self
.
generation_config
.
_from_model_config
and
self
.
generation_config
.
_original_object_hash
==
hash
(
self
.
generation_config
):
new_generation_config
=
GenerationConfig
.
from_model_config
(
self
.
config
)
new_generation_config
=
GenerationConfig
.
from_model_config
(
self
.
config
)
if
new_generation_config
!=
self
.
generation_config
:
if
new_generation_config
!=
self
.
generation_config
:
warnings
.
warn
(
warnings
.
warn
(
"You have modified the pretrained model configuration to control generation. This is a"
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" deprecated strategy to control generation and will be removed soon, in a future version."
" Please use a generation configuration file (see"
" Please use and modify the model generation configuration (see"
" https://huggingface.co/docs/transformers/main_classes/text_generation )"
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
)
)
self
.
generation_config
=
new_generation_config
self
.
generation_config
=
new_generation_config
generation_config
=
self
.
generation_config
generation_config
=
self
.
generation_config
...
...
tests/generation/test_utils.py
View file @
3319eb54
...
@@ -2880,7 +2880,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
...
@@ -2880,7 +2880,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
# Generation config max_length != 20 -> no warning
# Generation config max_length != 20 -> no warning
with
warnings
.
catch_warnings
(
record
=
True
)
as
warning_list
:
with
warnings
.
catch_warnings
(
record
=
True
)
as
warning_list
:
# generation_config is modified -> legacy mode is disabled = generation_config takes precedence
model
.
generation_config
.
max_length
=
10
model
.
generation_config
.
max_length
=
10
model
.
generation_config
.
_from_model_config
=
False
# otherwise model.config.max_length=20 takes precedence
model
.
generate
(
input_ids
)
model
.
generate
(
input_ids
)
self
.
assertEqual
(
len
(
warning_list
),
0
)
self
.
assertEqual
(
len
(
warning_list
),
0
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment