Unverified Commit 52d2e6f6 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add push to hub to feature extractor (#15632)

* Add push to hub to feature extractor

* Quality

* Clean up
parent 4f403ea8
...@@ -30,6 +30,7 @@ from .dynamic_module_utils import custom_object_save ...@@ -30,6 +30,7 @@ from .dynamic_module_utils import custom_object_save
from .file_utils import ( from .file_utils import (
FEATURE_EXTRACTOR_NAME, FEATURE_EXTRACTOR_NAME,
EntryNotFoundError, EntryNotFoundError,
PushToHubMixin,
RepositoryNotFoundError, RepositoryNotFoundError,
RevisionNotFoundError, RevisionNotFoundError,
TensorType, TensorType,
...@@ -37,6 +38,7 @@ from .file_utils import ( ...@@ -37,6 +38,7 @@ from .file_utils import (
_is_numpy, _is_numpy,
_is_torch_device, _is_torch_device,
cached_path, cached_path,
copy_func,
hf_bucket_url, hf_bucket_url,
is_flax_available, is_flax_available,
is_offline_mode, is_offline_mode,
...@@ -200,7 +202,7 @@ class BatchFeature(UserDict): ...@@ -200,7 +202,7 @@ class BatchFeature(UserDict):
return self return self
class FeatureExtractionMixin: class FeatureExtractionMixin(PushToHubMixin):
""" """
This is a feature extraction mixin used to provide saving/loading functionality for sequential and image feature This is a feature extraction mixin used to provide saving/loading functionality for sequential and image feature
extractors. extractors.
...@@ -308,7 +310,7 @@ class FeatureExtractionMixin: ...@@ -308,7 +310,7 @@ class FeatureExtractionMixin:
return cls.from_dict(feature_extractor_dict, **kwargs) return cls.from_dict(feature_extractor_dict, **kwargs)
def save_pretrained(self, save_directory: Union[str, os.PathLike]): def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
""" """
Save a feature_extractor object to the directory `save_directory`, so that it can be re-loaded using the Save a feature_extractor object to the directory `save_directory`, so that it can be re-loaded using the
[`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] class method. [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] class method.
...@@ -316,10 +318,27 @@ class FeatureExtractionMixin: ...@@ -316,10 +318,27 @@ class FeatureExtractionMixin:
Args: Args:
save_directory (`str` or `os.PathLike`): save_directory (`str` or `os.PathLike`):
Directory where the feature extractor JSON file will be saved (will be created if it does not exist). Directory where the feature extractor JSON file will be saved (will be created if it does not exist).
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your feature extractor to the Hugging Face model hub after saving it.
<Tip warning={true}>
Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
folder. Pass along `temp_dir=True` to use a temporary directory instead.
</Tip>
kwargs:
Additional key word arguments passed along to the [`~file_utils.PushToHubMixin.push_to_hub`] method.
""" """
if os.path.isfile(save_directory): if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo = self._create_or_get_repo(save_directory, **kwargs)
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
# loaded from the Hub. # loaded from the Hub.
if self._auto_class is not None: if self._auto_class is not None:
...@@ -330,7 +349,11 @@ class FeatureExtractionMixin: ...@@ -330,7 +349,11 @@ class FeatureExtractionMixin:
output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME) output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME)
self.to_json_file(output_feature_extractor_file) self.to_json_file(output_feature_extractor_file)
logger.info(f"Configuration saved in {output_feature_extractor_file}") logger.info(f"Feature extractor saved in {output_feature_extractor_file}")
if push_to_hub:
url = self._push_to_hub(repo, commit_message=commit_message)
logger.info(f"Feature extractor pushed to the hub in this commit: {url}")
@classmethod @classmethod
def get_feature_extractor_dict( def get_feature_extractor_dict(
...@@ -574,3 +597,9 @@ class FeatureExtractionMixin: ...@@ -574,3 +597,9 @@ class FeatureExtractionMixin:
raise ValueError(f"{auto_class} is not a valid auto class.") raise ValueError(f"{auto_class} is not a valid auto class.")
cls._auto_class = auto_class cls._auto_class = auto_class
FeatureExtractionMixin.push_to_hub = copy_func(FeatureExtractionMixin.push_to_hub)
FeatureExtractionMixin.push_to_hub.__doc__ = FeatureExtractionMixin.push_to_hub.__doc__.format(
object="feature extractor", object_class="AutoFeatureExtractor", object_files="feature extractor file"
)
...@@ -23,7 +23,7 @@ from pathlib import Path ...@@ -23,7 +23,7 @@ from pathlib import Path
from huggingface_hub import Repository, delete_repo, login from huggingface_hub import Repository, delete_repo, login
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers import AutoFeatureExtractor from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
from transformers.file_utils import is_torch_available, is_vision_available from transformers.file_utils import is_torch_available, is_vision_available
from transformers.testing_utils import PASS, USER, is_staging_test from transformers.testing_utils import PASS, USER, is_staging_test
...@@ -40,7 +40,6 @@ if is_torch_available(): ...@@ -40,7 +40,6 @@ if is_torch_available():
if is_vision_available(): if is_vision_available():
from PIL import Image from PIL import Image
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures") SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
...@@ -124,11 +123,47 @@ class ConfigPushToHubTester(unittest.TestCase): ...@@ -124,11 +123,47 @@ class ConfigPushToHubTester(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
try:
delete_repo(token=cls._token, name="test-feature-extractor")
except HTTPError:
pass
try:
delete_repo(token=cls._token, name="test-feature-extractor-org", organization="valid_org")
except HTTPError:
pass
try: try:
delete_repo(token=cls._token, name="test-dynamic-feature-extractor") delete_repo(token=cls._token, name="test-dynamic-feature-extractor")
except HTTPError: except HTTPError:
pass pass
def test_push_to_hub(self):
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
with tempfile.TemporaryDirectory() as tmp_dir:
feature_extractor.save_pretrained(
os.path.join(tmp_dir, "test-feature-extractor"), push_to_hub=True, use_auth_token=self._token
)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor")
for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k))
def test_push_to_hub_in_organization(self):
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
with tempfile.TemporaryDirectory() as tmp_dir:
feature_extractor.save_pretrained(
os.path.join(tmp_dir, "test-feature-extractor-org"),
push_to_hub=True,
use_auth_token=self._token,
organization="valid_org",
)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor-org")
for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k))
def test_push_to_hub_dynamic_feature_extractor(self): def test_push_to_hub_dynamic_feature_extractor(self):
CustomFeatureExtractor.register_for_auto_class() CustomFeatureExtractor.register_for_auto_class()
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR) feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment