"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "2791e60093cceddd19f6fec5e9a2637c02ebe3e0"
Unverified commit 1eda4a41 authored by Joao Gante, committed by GitHub

Generate: save generation config with the models' `.save_pretrained()` (#21264)

parent cf1a1eed
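In short: `save_pretrained()` now also writes `generation_config.json` whenever the model's `can_generate()` returns True, for the PyTorch, TensorFlow, and Flax base classes alike. A minimal sketch of the resulting behavior, assuming a generation-capable checkpoint ("gpt2" here is only an example):

import os
import tempfile

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # any generate-capable model

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)
    # config.json was always written; generation_config.json is new with this commit
    assert os.path.exists(os.path.join(tmp_dir, "config.json"))
    assert os.path.exists(os.path.join(tmp_dir, "generation_config.json"))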
@@ -1032,6 +1032,8 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
             custom_object_save(self, save_directory, config=self.config)

         self.config.save_pretrained(save_directory)
+        if self.can_generate():
+            self.generation_config.save_pretrained(save_directory)

         # save model
         output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
...
@@ -2306,6 +2306,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
             custom_object_save(self, save_directory, config=self.config)

         self.config.save_pretrained(save_directory)
+        if self.can_generate():
+            self.generation_config.save_pretrained(save_directory)

         # If we save using the predefined names, we can load using `from_pretrained`
         weights_name = SAFE_WEIGHTS_NAME if safe_serialization else TF2_WEIGHTS_NAME
...
@@ -1655,6 +1655,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         # Save the config
         if is_main_process:
             model_to_save.config.save_pretrained(save_directory)
+            if self.can_generate():
+                model_to_save.generation_config.save_pretrained(save_directory)

         # Save the model
         if state_dict is None:
...
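Because the saved file is read back by `from_pretrained()`, customized generation defaults now survive a save/load round trip. A hedged sketch (the tiny checkpoint name is illustrative, and it assumes the loading side already consumes `generation_config.json`, as it does in this release):

import tempfile

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model.generation_config.max_new_tokens = 32  # tweak a generation default

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)  # persists the tweak in generation_config.json
    reloaded = AutoModelForCausalLM.from_pretrained(tmp_dir)

assert reloaded.generation_config.max_new_tokens == 32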
@@ -17,11 +17,14 @@ import copy
 import tempfile
 import unittest

+from huggingface_hub import HfFolder, delete_repo, set_access_token
 from parameterized import parameterized
+from requests.exceptions import HTTPError

 from transformers import AutoConfig, GenerationConfig
+from transformers.testing_utils import TOKEN, USER, is_staging_test


-class LogitsProcessorTest(unittest.TestCase):
+class GenerationConfigTest(unittest.TestCase):
     @parameterized.expand([(None,), ("foo.json",)])
     def test_save_load_config(self, config_name):
         config = GenerationConfig(
@@ -74,3 +77,78 @@ class LogitsProcessorTest(unittest.TestCase):

         # `.update()` returns a dictionary of unused kwargs
         self.assertEqual(unused_kwargs, {"foo": "bar"})
+
+
+@is_staging_test
+class ConfigPushToHubTester(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls._token = TOKEN
+        set_access_token(TOKEN)
+        HfFolder.save_token(TOKEN)
+
+    @classmethod
+    def tearDownClass(cls):
+        try:
+            delete_repo(token=cls._token, repo_id="test-generation-config")
+        except HTTPError:
+            pass
+
+        try:
+            delete_repo(token=cls._token, repo_id="valid_org/test-generation-config-org")
+        except HTTPError:
+            pass
+
+    def test_push_to_hub(self):
+        config = GenerationConfig(
+            do_sample=True,
+            temperature=0.7,
+            length_penalty=1.0,
+        )
+        config.push_to_hub("test-generation-config", use_auth_token=self._token)
+
+        new_config = GenerationConfig.from_pretrained(f"{USER}/test-generation-config")
+        for k, v in config.to_dict().items():
+            if k != "transformers_version":
+                self.assertEqual(v, getattr(new_config, k))
+
+        # Reset repo
+        delete_repo(token=self._token, repo_id="test-generation-config")
+
+        # Push to hub via save_pretrained
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            config.save_pretrained(
+                tmp_dir, repo_id="test-generation-config", push_to_hub=True, use_auth_token=self._token
+            )
+
+        new_config = GenerationConfig.from_pretrained(f"{USER}/test-generation-config")
+        for k, v in config.to_dict().items():
+            if k != "transformers_version":
+                self.assertEqual(v, getattr(new_config, k))
+
+    def test_push_to_hub_in_organization(self):
+        config = GenerationConfig(
+            do_sample=True,
+            temperature=0.7,
+            length_penalty=1.0,
+        )
+        config.push_to_hub("valid_org/test-generation-config-org", use_auth_token=self._token)
+
+        new_config = GenerationConfig.from_pretrained("valid_org/test-generation-config-org")
+        for k, v in config.to_dict().items():
+            if k != "transformers_version":
+                self.assertEqual(v, getattr(new_config, k))
+
+        # Reset repo
+        delete_repo(token=self._token, repo_id="valid_org/test-generation-config-org")
+
+        # Push to hub via save_pretrained
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            config.save_pretrained(
+                tmp_dir, repo_id="valid_org/test-generation-config-org", push_to_hub=True, use_auth_token=self._token
+            )
+
+        new_config = GenerationConfig.from_pretrained("valid_org/test-generation-config-org")
+        for k, v in config.to_dict().items():
+            if k != "transformers_version":
+                self.assertEqual(v, getattr(new_config, k))
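The two staging tests above need a Hub token; the same round-trip property they check can be exercised locally with `save_pretrained`/`from_pretrained`, as in this sketch (`transformers_version` is skipped because it is stamped at save time):

import tempfile

from transformers import GenerationConfig

config = GenerationConfig(do_sample=True, temperature=0.7, length_penalty=1.0)

with tempfile.TemporaryDirectory() as tmp_dir:
    config.save_pretrained(tmp_dir)
    new_config = GenerationConfig.from_pretrained(tmp_dir)

# every user-set value should survive the round trip
for k, v in config.to_dict().items():
    if k != "transformers_version":
        assert v == getattr(new_config, k), k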
@@ -63,6 +63,8 @@ from transformers.testing_utils import (
     torch_device,
 )
 from transformers.utils import (
+    CONFIG_NAME,
+    GENERATION_CONFIG_NAME,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
@@ -275,6 +277,13 @@ class ModelTesterMixin:
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_pretrained(tmpdirname)
+
+            # the config file (and the generation config file, if it can generate) should be saved
+            self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME)))
+            self.assertEqual(
+                model.can_generate(), os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME))
+            )
+
             model = model_class.from_pretrained(tmpdirname)
             model.to(torch_device)

         with torch.no_grad():
...
@@ -36,7 +36,7 @@ from transformers.testing_utils import (
     require_flax,
     torch_device,
 )
-from transformers.utils import logging
+from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging
 from transformers.utils.generic import ModelOutput
@@ -395,6 +395,13 @@ class FlaxModelTesterMixin:
         # verify that normal save_pretrained works as expected
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_pretrained(tmpdirname)
+
+            # the config file (and the generation config file, if it can generate) should be saved
+            self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME)))
+            self.assertEqual(
+                model.can_generate(), os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME))
+            )
+
             model_loaded = model_class.from_pretrained(tmpdirname)

         outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
...
@@ -50,7 +50,14 @@ from transformers.testing_utils import (  # noqa: F401
     tooslow,
     torch_device,
 )
-from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
+from transformers.utils import (
+    CONFIG_NAME,
+    GENERATION_CONFIG_NAME,
+    SAFE_WEIGHTS_NAME,
+    TF2_WEIGHTS_INDEX_NAME,
+    TF2_WEIGHTS_NAME,
+    logging,
+)
 from transformers.utils.generic import ModelOutput
@@ -226,6 +233,13 @@ class TFModelTesterMixin:
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_pretrained(tmpdirname, saved_model=False)
+
+            # the config file (and the generation config file, if it can generate) should be saved
+            self.assertTrue(os.path.exists(os.path.join(tmpdirname, CONFIG_NAME)))
+            self.assertEqual(
+                model.can_generate(), os.path.exists(os.path.join(tmpdirname, GENERATION_CONFIG_NAME))
+            )
+
             model = model_class.from_pretrained(tmpdirname)
             after_outputs = model(self._prepare_for_class(inputs_dict, model_class))
...
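All three tester mixins assert the same invariant: `config.json` is always written, and `generation_config.json` appears exactly when `can_generate()` is True. A standalone sketch of that check, using the constants from `transformers.utils` (the tiny test checkpoints are only examples):

import os
import tempfile

from transformers import AutoModel, AutoModelForCausalLM
from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME

def check_saved_configs(model):
    # mirrors the assertion shared by the three tester mixins
    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir)
        assert os.path.exists(os.path.join(tmp_dir, CONFIG_NAME))
        assert model.can_generate() == os.path.exists(os.path.join(tmp_dir, GENERATION_CONFIG_NAME))

check_saved_configs(AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert"))  # encoder-only: no file
check_saved_configs(AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2"))  # writes it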