chenpangpang / transformers / Commits / 1eda4a41

Unverified commit 1eda4a41, authored Jan 23, 2023 by Joao Gante, committed by GitHub on Jan 23, 2023

Generate: save generation config with the models' `.save_pretrained()` (#21264)
parent cf1a1eed

Showing 7 changed files with 117 additions and 3 deletions
src/transformers/modeling_flax_utils.py        +2  -0
src/transformers/modeling_tf_utils.py          +2  -0
src/transformers/modeling_utils.py             +2  -0
tests/generation/test_configuration_utils.py   +79 -1
tests/test_modeling_common.py                  +9  -0
tests/test_modeling_flax_common.py             +8  -1
tests/test_modeling_tf_common.py               +15 -1
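In short, this commit extends `save_pretrained()` so that, whenever `can_generate()` is true, the model's generation configuration is written to `generation_config.json` alongside `config.json`, for the PyTorch, TensorFlow, and Flax base classes alike. A minimal sketch of the resulting behavior, using a tiny randomly initialized GPT-2 as a stand-in (the model sizes below are illustrative, not part of the commit):

import os
import tempfile

from transformers import GPT2Config, GPT2LMHeadModel
from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME

# tiny random model so the sketch runs without downloading pretrained weights
model = GPT2LMHeadModel(GPT2Config(n_layer=1, n_head=2, n_embd=16))

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)
    print(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME)))             # True
    print(os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME)))  # True: GPT2LMHeadModel can generate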
src/transformers/modeling_flax_utils.py

@@ -1032,6 +1032,8 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
             custom_object_save(self, save_directory, config=self.config)
 
         self.config.save_pretrained(save_directory)
+        if self.can_generate():
+            self.generation_config.save_pretrained(save_directory)
 
         # save model
         output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
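The Flax change adds the same two-line guard used in the other frameworks. A hedged sketch of the effect, assuming jax and flax are installed (the tiny GPT-2 sizes are illustrative):

import os
import tempfile

from transformers import FlaxGPT2LMHeadModel, GPT2Config

# tiny randomly initialized Flax model
model = FlaxGPT2LMHeadModel(GPT2Config(n_layer=1, n_head=2, n_embd=16))

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)
    # expected to list config.json, flax_model.msgpack and generation_config.json
    print(sorted(os.listdir(tmpdirname)))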
src/transformers/modeling_tf_utils.py

@@ -2306,6 +2306,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
             custom_object_save(self, save_directory, config=self.config)
 
         self.config.save_pretrained(save_directory)
+        if self.can_generate():
+            self.generation_config.save_pretrained(save_directory)
 
         # If we save using the predefined names, we can load using `from_pretrained`
         weights_name = SAFE_WEIGHTS_NAME if safe_serialization else TF2_WEIGHTS_NAME
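The TensorFlow base class receives the identical guard. A sketch assuming TensorFlow is installed; the dummy forward pass builds the weights so they can be serialized (the tiny sizes are illustrative):

import os
import tempfile

from transformers import GPT2Config, TFGPT2LMHeadModel
from transformers.utils import GENERATION_CONFIG_NAME

model = TFGPT2LMHeadModel(GPT2Config(n_layer=1, n_head=2, n_embd=16))
model(model.dummy_inputs)  # build the weights before saving

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname, saved_model=False)
    print(os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME)))  # True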
src/transformers/modeling_utils.py

@@ -1655,6 +1655,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # Save the config
         if is_main_process:
             model_to_save.config.save_pretrained(save_directory)
+            if self.can_generate():
+                model_to_save.generation_config.save_pretrained(save_directory)
 
         # Save the model
         if state_dict is None:
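Since `from_pretrained()` already reads `generation_config.json` when a model can generate, writing the file here makes generation settings round-trip with the model. A sketch, assuming a transformers version that includes this change:

import tempfile

from transformers import GPT2Config, GPT2LMHeadModel

model = GPT2LMHeadModel(GPT2Config(n_layer=1, n_head=2, n_embd=16))
model.generation_config.max_new_tokens = 32  # customize the default generation behavior

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)
    reloaded = GPT2LMHeadModel.from_pretrained(tmpdirname)

print(reloaded.generation_config.max_new_tokens)  # 32, restored from generation_config.json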
tests/generation/test_configuration_utils.py

@@ -17,11 +17,14 @@ import copy
 import tempfile
 import unittest
 
 from huggingface_hub import HfFolder, delete_repo, set_access_token
+from parameterized import parameterized
 from requests.exceptions import HTTPError
 
 from transformers import AutoConfig, GenerationConfig
 from transformers.testing_utils import TOKEN, USER, is_staging_test
 
 
-class LogitsProcessorTest(unittest.TestCase):
+class GenerationConfigTest(unittest.TestCase):
+    @parameterized.expand([(None,), ("foo.json",)])
+    def test_save_load_config(self, config_name):
         config = GenerationConfig(

@@ -74,3 +77,78 @@ class LogitsProcessorTest(unittest.TestCase):
 
         # `.update()` returns a dictionary of unused kwargs
         self.assertEqual(unused_kwargs, {"foo": "bar"})
+
+
+@is_staging_test
+class ConfigPushToHubTester(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls._token = TOKEN
+        set_access_token(TOKEN)
+        HfFolder.save_token(TOKEN)
+
+    @classmethod
+    def tearDownClass(cls):
+        try:
+            delete_repo(token=cls._token, repo_id="test-generation-config")
+        except HTTPError:
+            pass
+
+        try:
+            delete_repo(token=cls._token, repo_id="valid_org/test-generation-config-org")
+        except HTTPError:
+            pass
+
+    def test_push_to_hub(self):
+        config = GenerationConfig(
+            do_sample=True,
+            temperature=0.7,
+            length_penalty=1.0,
+        )
+        config.push_to_hub("test-generation-config", use_auth_token=self._token)
+
+        new_config = GenerationConfig.from_pretrained(f"{USER}/test-generation-config")
+        for k, v in config.to_dict().items():
+            if k != "transformers_version":
+                self.assertEqual(v, getattr(new_config, k))
+
+        # Reset repo
+        delete_repo(token=self._token, repo_id="test-generation-config")
+
+        # Push to hub via save_pretrained
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            config.save_pretrained(tmp_dir, repo_id="test-generation-config", push_to_hub=True, use_auth_token=self._token)
+
+        new_config = GenerationConfig.from_pretrained(f"{USER}/test-generation-config")
+        for k, v in config.to_dict().items():
+            if k != "transformers_version":
+                self.assertEqual(v, getattr(new_config, k))
+
+    def test_push_to_hub_in_organization(self):
+        config = GenerationConfig(
+            do_sample=True,
+            temperature=0.7,
+            length_penalty=1.0,
+        )
+        config.push_to_hub("valid_org/test-generation-config-org", use_auth_token=self._token)
+
+        new_config = GenerationConfig.from_pretrained("valid_org/test-generation-config-org")
+        for k, v in config.to_dict().items():
+            if k != "transformers_version":
+                self.assertEqual(v, getattr(new_config, k))
+
+        # Reset repo
+        delete_repo(token=self._token, repo_id="valid_org/test-generation-config-org")
+
+        # Push to hub via save_pretrained
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            config.save_pretrained(
+                tmp_dir, repo_id="valid_org/test-generation-config-org", push_to_hub=True, use_auth_token=self._token
+            )
+
+        new_config = GenerationConfig.from_pretrained("valid_org/test-generation-config-org")
+        for k, v in config.to_dict().items():
+            if k != "transformers_version":
+                self.assertEqual(v, getattr(new_config, k))
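The new tests exercise `GenerationConfig` serialization both locally and against the Hub. Stripped of the test harness, the local round trip looks roughly like this (no token or network access needed):

import tempfile

from transformers import GenerationConfig

config = GenerationConfig(do_sample=True, temperature=0.7, length_penalty=1.0)

with tempfile.TemporaryDirectory() as tmp_dir:
    config.save_pretrained(tmp_dir)
    loaded = GenerationConfig.from_pretrained(tmp_dir)

print(loaded.do_sample, loaded.temperature, loaded.length_penalty)  # True 0.7 1.0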
tests/test_modeling_common.py

@@ -63,6 +63,8 @@ from transformers.testing_utils import (
     torch_device,
 )
 from transformers.utils import (
+    CONFIG_NAME,
+    GENERATION_CONFIG_NAME,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,

@@ -275,6 +277,13 @@ class ModelTesterMixin:
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_pretrained(tmpdirname)
+
+            # the config file (and the generation config file, if it can generate) should be saved
+            self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME)))
+            self.assertEqual(
+                model.can_generate(), os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME))
+            )
+
             model = model_class.from_pretrained(tmpdirname)
             model.to(torch_device)
 
         with torch.no_grad():
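The assertion ties the presence of `generation_config.json` to `model.can_generate()`, so models that cannot generate are unaffected. A sketch with a tiny random BERT encoder (sizes are illustrative):

import os
import tempfile

from transformers import BertConfig, BertModel
from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME

model = BertModel(BertConfig(num_hidden_layers=1, hidden_size=32, num_attention_heads=2, intermediate_size=64))

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)
    print(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME)))             # True
    print(os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME)))  # False: BertModel.can_generate() is False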
tests/test_modeling_flax_common.py

@@ -36,7 +36,7 @@ from transformers.testing_utils import (
     require_flax,
     torch_device,
 )
-from transformers.utils import logging
+from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging
 from transformers.utils.generic import ModelOutput

@@ -395,6 +395,13 @@ class FlaxModelTesterMixin:
             # verify that normal save_pretrained works as expected
             with tempfile.TemporaryDirectory() as tmpdirname:
                 model.save_pretrained(tmpdirname)
+
+                # the config file (and the generation config file, if it can generate) should be saved
+                self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME)))
+                self.assertEqual(
+                    model.can_generate(), os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME))
+                )
+
                 model_loaded = model_class.from_pretrained(tmpdirname)
 
             outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
tests/test_modeling_tf_common.py

@@ -50,7 +50,14 @@ from transformers.testing_utils import (  # noqa: F401
     tooslow,
     torch_device,
 )
-from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
+from transformers.utils import (
+    CONFIG_NAME,
+    GENERATION_CONFIG_NAME,
+    SAFE_WEIGHTS_NAME,
+    TF2_WEIGHTS_INDEX_NAME,
+    TF2_WEIGHTS_NAME,
+    logging,
+)
 from transformers.utils.generic import ModelOutput

@@ -226,6 +233,13 @@ class TFModelTesterMixin:
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_pretrained(tmpdirname, saved_model=False)
+
+            # the config file (and the generation config file, if it can generate) should be saved
+            self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME)))
+            self.assertEqual(
+                model.can_generate(), os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME))
+            )
+
             model = model_class.from_pretrained(tmpdirname)
             after_outputs = model(self._prepare_for_class(inputs_dict, model_class))