Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
1eda4a41
Unverified
Commit
1eda4a41
authored
Jan 23, 2023
by
Joao Gante
Committed by
GitHub
Jan 23, 2023
Browse files
Generate: save generation config with the models' `.save_pretrained()` (#21264)
parent
cf1a1eed
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
117 additions
and
3 deletions
+117
-3
src/transformers/modeling_flax_utils.py
src/transformers/modeling_flax_utils.py
+2
-0
src/transformers/modeling_tf_utils.py
src/transformers/modeling_tf_utils.py
+2
-0
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+2
-0
tests/generation/test_configuration_utils.py
tests/generation/test_configuration_utils.py
+79
-1
tests/test_modeling_common.py
tests/test_modeling_common.py
+9
-0
tests/test_modeling_flax_common.py
tests/test_modeling_flax_common.py
+8
-1
tests/test_modeling_tf_common.py
tests/test_modeling_tf_common.py
+15
-1
No files found.
src/transformers/modeling_flax_utils.py
View file @
1eda4a41
...
...
@@ -1032,6 +1032,8 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
custom_object_save
(
self
,
save_directory
,
config
=
self
.
config
)
self
.
config
.
save_pretrained
(
save_directory
)
if
self
.
can_generate
():
self
.
generation_config
.
save_pretrained
(
save_directory
)
# save model
output_model_file
=
os
.
path
.
join
(
save_directory
,
FLAX_WEIGHTS_NAME
)
...
...
src/transformers/modeling_tf_utils.py
View file @
1eda4a41
...
...
@@ -2306,6 +2306,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
custom_object_save
(
self
,
save_directory
,
config
=
self
.
config
)
self
.
config
.
save_pretrained
(
save_directory
)
if
self
.
can_generate
():
self
.
generation_config
.
save_pretrained
(
save_directory
)
# If we save using the predefined names, we can load using `from_pretrained`
weights_name
=
SAFE_WEIGHTS_NAME
if
safe_serialization
else
TF2_WEIGHTS_NAME
...
...
src/transformers/modeling_utils.py
View file @
1eda4a41
...
...
@@ -1655,6 +1655,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Save the config
if
is_main_process
:
model_to_save
.
config
.
save_pretrained
(
save_directory
)
if
self
.
can_generate
():
model_to_save
.
generation_config
.
save_pretrained
(
save_directory
)
# Save the model
if
state_dict
is
None
:
...
...
tests/generation/test_configuration_utils.py
View file @
1eda4a41
...
...
@@ -17,11 +17,14 @@ import copy
import
tempfile
import
unittest
from
huggingface_hub
import
HfFolder
,
delete_repo
,
set_access_token
from
parameterized
import
parameterized
from
requests.exceptions
import
HTTPError
from
transformers
import
AutoConfig
,
GenerationConfig
from
transformers.testing_utils
import
TOKEN
,
USER
,
is_staging_test
class
LogitsProcessor
Test
(
unittest
.
TestCase
):
class
GenerationConfig
Test
(
unittest
.
TestCase
):
@
parameterized
.
expand
([(
None
,),
(
"foo.json"
,)])
def
test_save_load_config
(
self
,
config_name
):
config
=
GenerationConfig
(
...
...
@@ -74,3 +77,78 @@ class LogitsProcessorTest(unittest.TestCase):
# `.update()` returns a dictionary of unused kwargs
self
.
assertEqual
(
unused_kwargs
,
{
"foo"
:
"bar"
})
@
is_staging_test
class
ConfigPushToHubTester
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
_token
=
TOKEN
set_access_token
(
TOKEN
)
HfFolder
.
save_token
(
TOKEN
)
@
classmethod
def
tearDownClass
(
cls
):
try
:
delete_repo
(
token
=
cls
.
_token
,
repo_id
=
"test-generation-config"
)
except
HTTPError
:
pass
try
:
delete_repo
(
token
=
cls
.
_token
,
repo_id
=
"valid_org/test-generation-config-org"
)
except
HTTPError
:
pass
def
test_push_to_hub
(
self
):
config
=
GenerationConfig
(
do_sample
=
True
,
temperature
=
0.7
,
length_penalty
=
1.0
,
)
config
.
push_to_hub
(
"test-generation-config"
,
use_auth_token
=
self
.
_token
)
new_config
=
GenerationConfig
.
from_pretrained
(
f
"
{
USER
}
/test-generation-config"
)
for
k
,
v
in
config
.
to_dict
().
items
():
if
k
!=
"transformers_version"
:
self
.
assertEqual
(
v
,
getattr
(
new_config
,
k
))
# Reset repo
delete_repo
(
token
=
self
.
_token
,
repo_id
=
"test-generation-config"
)
# Push to hub via save_pretrained
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
config
.
save_pretrained
(
tmp_dir
,
repo_id
=
"test-generation-config"
,
push_to_hub
=
True
,
use_auth_token
=
self
.
_token
)
new_config
=
GenerationConfig
.
from_pretrained
(
f
"
{
USER
}
/test-generation-config"
)
for
k
,
v
in
config
.
to_dict
().
items
():
if
k
!=
"transformers_version"
:
self
.
assertEqual
(
v
,
getattr
(
new_config
,
k
))
def
test_push_to_hub_in_organization
(
self
):
config
=
GenerationConfig
(
do_sample
=
True
,
temperature
=
0.7
,
length_penalty
=
1.0
,
)
config
.
push_to_hub
(
"valid_org/test-generation-config-org"
,
use_auth_token
=
self
.
_token
)
new_config
=
GenerationConfig
.
from_pretrained
(
"valid_org/test-generation-config-org"
)
for
k
,
v
in
config
.
to_dict
().
items
():
if
k
!=
"transformers_version"
:
self
.
assertEqual
(
v
,
getattr
(
new_config
,
k
))
# Reset repo
delete_repo
(
token
=
self
.
_token
,
repo_id
=
"valid_org/test-generation-config-org"
)
# Push to hub via save_pretrained
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
config
.
save_pretrained
(
tmp_dir
,
repo_id
=
"valid_org/test-generation-config-org"
,
push_to_hub
=
True
,
use_auth_token
=
self
.
_token
)
new_config
=
GenerationConfig
.
from_pretrained
(
"valid_org/test-generation-config-org"
)
for
k
,
v
in
config
.
to_dict
().
items
():
if
k
!=
"transformers_version"
:
self
.
assertEqual
(
v
,
getattr
(
new_config
,
k
))
tests/test_modeling_common.py
View file @
1eda4a41
...
...
@@ -63,6 +63,8 @@ from transformers.testing_utils import (
torch_device
,
)
from
transformers.utils
import
(
CONFIG_NAME
,
GENERATION_CONFIG_NAME
,
SAFE_WEIGHTS_INDEX_NAME
,
SAFE_WEIGHTS_NAME
,
WEIGHTS_INDEX_NAME
,
...
...
@@ -275,6 +277,13 @@ class ModelTesterMixin:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
model
.
save_pretrained
(
tmpdirname
)
# the config file (and the generation config file, if it can generate) should be saved
self
.
assertTrue
(
os
.
path
.
exists
(
os
.
path
.
join
(
tmpdirname
,
CONFIG_NAME
)))
self
.
assertEqual
(
model
.
can_generate
(),
os
.
path
.
exists
(
os
.
path
.
join
(
tmpdirname
,
GENERATION_CONFIG_NAME
))
)
model
=
model_class
.
from_pretrained
(
tmpdirname
)
model
.
to
(
torch_device
)
with
torch
.
no_grad
():
...
...
tests/test_modeling_flax_common.py
View file @
1eda4a41
...
...
@@ -36,7 +36,7 @@ from transformers.testing_utils import (
require_flax
,
torch_device
,
)
from
transformers.utils
import
logging
from
transformers.utils
import
CONFIG_NAME
,
GENERATION_CONFIG_NAME
,
logging
from
transformers.utils.generic
import
ModelOutput
...
...
@@ -395,6 +395,13 @@ class FlaxModelTesterMixin:
# verify that normal save_pretrained works as expected
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
model
.
save_pretrained
(
tmpdirname
)
# the config file (and the generation config file, if it can generate) should be saved
self
.
assertTrue
(
os
.
path
.
exists
(
os
.
path
.
join
(
tmpdirname
,
CONFIG_NAME
)))
self
.
assertEqual
(
model
.
can_generate
(),
os
.
path
.
exists
(
os
.
path
.
join
(
tmpdirname
,
GENERATION_CONFIG_NAME
))
)
model_loaded
=
model_class
.
from_pretrained
(
tmpdirname
)
outputs_loaded
=
model_loaded
(
**
prepared_inputs_dict
).
to_tuple
()
...
...
tests/test_modeling_tf_common.py
View file @
1eda4a41
...
...
@@ -50,7 +50,14 @@ from transformers.testing_utils import ( # noqa: F401
tooslow
,
torch_device
,
)
from
transformers.utils
import
SAFE_WEIGHTS_NAME
,
TF2_WEIGHTS_INDEX_NAME
,
TF2_WEIGHTS_NAME
,
logging
from
transformers.utils
import
(
CONFIG_NAME
,
GENERATION_CONFIG_NAME
,
SAFE_WEIGHTS_NAME
,
TF2_WEIGHTS_INDEX_NAME
,
TF2_WEIGHTS_NAME
,
logging
,
)
from
transformers.utils.generic
import
ModelOutput
...
...
@@ -226,6 +233,13 @@ class TFModelTesterMixin:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
model
.
save_pretrained
(
tmpdirname
,
saved_model
=
False
)
# the config file (and the generation config file, if it can generate) should be saved
self
.
assertTrue
(
os
.
path
.
exists
(
os
.
path
.
join
(
tmpdirname
,
CONFIG_NAME
)))
self
.
assertEqual
(
model
.
can_generate
(),
os
.
path
.
exists
(
os
.
path
.
join
(
tmpdirname
,
GENERATION_CONFIG_NAME
))
)
model
=
model_class
.
from_pretrained
(
tmpdirname
)
after_outputs
=
model
(
self
.
_prepare_for_class
(
inputs_dict
,
model_class
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment