Unverified Commit 52d2e6f6 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add push to hub to feature extractor (#15632)

* Add push to hub to feature extractor

* Quality

* Clean up
parent 4f403ea8
......@@ -30,6 +30,7 @@ from .dynamic_module_utils import custom_object_save
from .file_utils import (
FEATURE_EXTRACTOR_NAME,
EntryNotFoundError,
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
TensorType,
......@@ -37,6 +38,7 @@ from .file_utils import (
_is_numpy,
_is_torch_device,
cached_path,
copy_func,
hf_bucket_url,
is_flax_available,
is_offline_mode,
......@@ -200,7 +202,7 @@ class BatchFeature(UserDict):
return self
class FeatureExtractionMixin:
class FeatureExtractionMixin(PushToHubMixin):
"""
This is a feature extraction mixin used to provide saving/loading functionality for sequential and image feature
extractors.
......@@ -308,7 +310,7 @@ class FeatureExtractionMixin:
return cls.from_dict(feature_extractor_dict, **kwargs)
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a feature_extractor object to the directory `save_directory`, so that it can be re-loaded using the
[`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] class method.
......@@ -316,10 +318,27 @@ class FeatureExtractionMixin:
Args:
save_directory (`str` or `os.PathLike`):
Directory where the feature extractor JSON file will be saved (will be created if it does not exist).
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your feature extractor to the Hugging Face model hub after saving it.
<Tip warning={true}>
Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
folder. Pass along `temp_dir=True` to use a temporary directory instead.
</Tip>
kwargs:
Additional key word arguments passed along to the [`~file_utils.PushToHubMixin.push_to_hub`] method.
"""
if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo = self._create_or_get_repo(save_directory, **kwargs)
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
# loaded from the Hub.
if self._auto_class is not None:
......@@ -330,7 +349,11 @@ class FeatureExtractionMixin:
output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME)
self.to_json_file(output_feature_extractor_file)
logger.info(f"Configuration saved in {output_feature_extractor_file}")
logger.info(f"Feature extractor saved in {output_feature_extractor_file}")
if push_to_hub:
url = self._push_to_hub(repo, commit_message=commit_message)
logger.info(f"Feature extractor pushed to the hub in this commit: {url}")
@classmethod
def get_feature_extractor_dict(
......@@ -574,3 +597,9 @@ class FeatureExtractionMixin:
raise ValueError(f"{auto_class} is not a valid auto class.")
cls._auto_class = auto_class
FeatureExtractionMixin.push_to_hub = copy_func(FeatureExtractionMixin.push_to_hub)
FeatureExtractionMixin.push_to_hub.__doc__ = FeatureExtractionMixin.push_to_hub.__doc__.format(
object="feature extractor", object_class="AutoFeatureExtractor", object_files="feature extractor file"
)
......@@ -23,7 +23,7 @@ from pathlib import Path
from huggingface_hub import Repository, delete_repo, login
from requests.exceptions import HTTPError
from transformers import AutoFeatureExtractor
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
from transformers.file_utils import is_torch_available, is_vision_available
from transformers.testing_utils import PASS, USER, is_staging_test
......@@ -40,7 +40,6 @@ if is_torch_available():
if is_vision_available():
from PIL import Image
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
......@@ -124,11 +123,47 @@ class ConfigPushToHubTester(unittest.TestCase):
@classmethod
def tearDownClass(cls):
try:
delete_repo(token=cls._token, name="test-feature-extractor")
except HTTPError:
pass
try:
delete_repo(token=cls._token, name="test-feature-extractor-org", organization="valid_org")
except HTTPError:
pass
try:
delete_repo(token=cls._token, name="test-dynamic-feature-extractor")
except HTTPError:
pass
def test_push_to_hub(self):
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
with tempfile.TemporaryDirectory() as tmp_dir:
feature_extractor.save_pretrained(
os.path.join(tmp_dir, "test-feature-extractor"), push_to_hub=True, use_auth_token=self._token
)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor")
for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k))
def test_push_to_hub_in_organization(self):
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
with tempfile.TemporaryDirectory() as tmp_dir:
feature_extractor.save_pretrained(
os.path.join(tmp_dir, "test-feature-extractor-org"),
push_to_hub=True,
use_auth_token=self._token,
organization="valid_org",
)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor-org")
for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k))
def test_push_to_hub_dynamic_feature_extractor(self):
CustomFeatureExtractor.register_for_auto_class()
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment