Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
52d2e6f6
Unverified
Commit
52d2e6f6
authored
Feb 11, 2022
by
Sylvain Gugger
Committed by
GitHub
Feb 11, 2022
Browse files
Add push to hub to feature extractor (#15632)
* Add push to hub to feature extractor * Quality * Clean up
parent
4f403ea8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
69 additions
and
5 deletions
+69
-5
src/transformers/feature_extraction_utils.py
src/transformers/feature_extraction_utils.py
+32
-3
tests/test_feature_extraction_common.py
tests/test_feature_extraction_common.py
+37
-2
No files found.
src/transformers/feature_extraction_utils.py
View file @
52d2e6f6
...
@@ -30,6 +30,7 @@ from .dynamic_module_utils import custom_object_save
...
@@ -30,6 +30,7 @@ from .dynamic_module_utils import custom_object_save
from
.file_utils
import
(
from
.file_utils
import
(
FEATURE_EXTRACTOR_NAME
,
FEATURE_EXTRACTOR_NAME
,
EntryNotFoundError
,
EntryNotFoundError
,
PushToHubMixin
,
RepositoryNotFoundError
,
RepositoryNotFoundError
,
RevisionNotFoundError
,
RevisionNotFoundError
,
TensorType
,
TensorType
,
...
@@ -37,6 +38,7 @@ from .file_utils import (
...
@@ -37,6 +38,7 @@ from .file_utils import (
_is_numpy
,
_is_numpy
,
_is_torch_device
,
_is_torch_device
,
cached_path
,
cached_path
,
copy_func
,
hf_bucket_url
,
hf_bucket_url
,
is_flax_available
,
is_flax_available
,
is_offline_mode
,
is_offline_mode
,
...
@@ -200,7 +202,7 @@ class BatchFeature(UserDict):
...
@@ -200,7 +202,7 @@ class BatchFeature(UserDict):
return
self
return
self
class
FeatureExtractionMixin
:
class
FeatureExtractionMixin
(
PushToHubMixin
)
:
"""
"""
This is a feature extraction mixin used to provide saving/loading functionality for sequential and image feature
This is a feature extraction mixin used to provide saving/loading functionality for sequential and image feature
extractors.
extractors.
...
@@ -308,7 +310,7 @@ class FeatureExtractionMixin:
...
@@ -308,7 +310,7 @@ class FeatureExtractionMixin:
return
cls
.
from_dict
(
feature_extractor_dict
,
**
kwargs
)
return
cls
.
from_dict
(
feature_extractor_dict
,
**
kwargs
)
def
save_pretrained
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
]):
def
save_pretrained
(
self
,
save_directory
:
Union
[
str
,
os
.
PathLike
]
,
push_to_hub
:
bool
=
False
,
**
kwargs
):
"""
"""
Save a feature_extractor object to the directory `save_directory`, so that it can be re-loaded using the
Save a feature_extractor object to the directory `save_directory`, so that it can be re-loaded using the
[`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] class method.
[`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] class method.
...
@@ -316,10 +318,27 @@ class FeatureExtractionMixin:
...
@@ -316,10 +318,27 @@ class FeatureExtractionMixin:
Args:
Args:
save_directory (`str` or `os.PathLike`):
save_directory (`str` or `os.PathLike`):
Directory where the feature extractor JSON file will be saved (will be created if it does not exist).
Directory where the feature extractor JSON file will be saved (will be created if it does not exist).
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your feature extractor to the Hugging Face model hub after saving it.
<Tip warning={true}>
Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
folder. Pass along `temp_dir=True` to use a temporary directory instead.
</Tip>
kwargs:
Additional key word arguments passed along to the [`~file_utils.PushToHubMixin.push_to_hub`] method.
"""
"""
if
os
.
path
.
isfile
(
save_directory
):
if
os
.
path
.
isfile
(
save_directory
):
raise
AssertionError
(
f
"Provided path (
{
save_directory
}
) should be a directory, not a file"
)
raise
AssertionError
(
f
"Provided path (
{
save_directory
}
) should be a directory, not a file"
)
if
push_to_hub
:
commit_message
=
kwargs
.
pop
(
"commit_message"
,
None
)
repo
=
self
.
_create_or_get_repo
(
save_directory
,
**
kwargs
)
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
# loaded from the Hub.
# loaded from the Hub.
if
self
.
_auto_class
is
not
None
:
if
self
.
_auto_class
is
not
None
:
...
@@ -330,7 +349,11 @@ class FeatureExtractionMixin:
...
@@ -330,7 +349,11 @@ class FeatureExtractionMixin:
output_feature_extractor_file
=
os
.
path
.
join
(
save_directory
,
FEATURE_EXTRACTOR_NAME
)
output_feature_extractor_file
=
os
.
path
.
join
(
save_directory
,
FEATURE_EXTRACTOR_NAME
)
self
.
to_json_file
(
output_feature_extractor_file
)
self
.
to_json_file
(
output_feature_extractor_file
)
logger
.
info
(
f
"Configuration saved in
{
output_feature_extractor_file
}
"
)
logger
.
info
(
f
"Feature extractor saved in
{
output_feature_extractor_file
}
"
)
if
push_to_hub
:
url
=
self
.
_push_to_hub
(
repo
,
commit_message
=
commit_message
)
logger
.
info
(
f
"Feature extractor pushed to the hub in this commit:
{
url
}
"
)
@
classmethod
@
classmethod
def
get_feature_extractor_dict
(
def
get_feature_extractor_dict
(
...
@@ -574,3 +597,9 @@ class FeatureExtractionMixin:
...
@@ -574,3 +597,9 @@ class FeatureExtractionMixin:
raise
ValueError
(
f
"
{
auto_class
}
is not a valid auto class."
)
raise
ValueError
(
f
"
{
auto_class
}
is not a valid auto class."
)
cls
.
_auto_class
=
auto_class
cls
.
_auto_class
=
auto_class
FeatureExtractionMixin
.
push_to_hub
=
copy_func
(
FeatureExtractionMixin
.
push_to_hub
)
FeatureExtractionMixin
.
push_to_hub
.
__doc__
=
FeatureExtractionMixin
.
push_to_hub
.
__doc__
.
format
(
object
=
"feature extractor"
,
object_class
=
"AutoFeatureExtractor"
,
object_files
=
"feature extractor file"
)
tests/test_feature_extraction_common.py
View file @
52d2e6f6
...
@@ -23,7 +23,7 @@ from pathlib import Path
...
@@ -23,7 +23,7 @@ from pathlib import Path
from
huggingface_hub
import
Repository
,
delete_repo
,
login
from
huggingface_hub
import
Repository
,
delete_repo
,
login
from
requests.exceptions
import
HTTPError
from
requests.exceptions
import
HTTPError
from
transformers
import
AutoFeatureExtractor
from
transformers
import
AutoFeatureExtractor
,
Wav2Vec2FeatureExtractor
from
transformers.file_utils
import
is_torch_available
,
is_vision_available
from
transformers.file_utils
import
is_torch_available
,
is_vision_available
from
transformers.testing_utils
import
PASS
,
USER
,
is_staging_test
from
transformers.testing_utils
import
PASS
,
USER
,
is_staging_test
...
@@ -40,7 +40,6 @@ if is_torch_available():
...
@@ -40,7 +40,6 @@ if is_torch_available():
if
is_vision_available
():
if
is_vision_available
():
from
PIL
import
Image
from
PIL
import
Image
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"fixtures"
)
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"fixtures"
)
...
@@ -124,11 +123,47 @@ class ConfigPushToHubTester(unittest.TestCase):
...
@@ -124,11 +123,47 @@ class ConfigPushToHubTester(unittest.TestCase):
@
classmethod
@
classmethod
def
tearDownClass
(
cls
):
def
tearDownClass
(
cls
):
try
:
delete_repo
(
token
=
cls
.
_token
,
name
=
"test-feature-extractor"
)
except
HTTPError
:
pass
try
:
delete_repo
(
token
=
cls
.
_token
,
name
=
"test-feature-extractor-org"
,
organization
=
"valid_org"
)
except
HTTPError
:
pass
try
:
try
:
delete_repo
(
token
=
cls
.
_token
,
name
=
"test-dynamic-feature-extractor"
)
delete_repo
(
token
=
cls
.
_token
,
name
=
"test-dynamic-feature-extractor"
)
except
HTTPError
:
except
HTTPError
:
pass
pass
def
test_push_to_hub
(
self
):
feature_extractor
=
Wav2Vec2FeatureExtractor
.
from_pretrained
(
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR
)
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
feature_extractor
.
save_pretrained
(
os
.
path
.
join
(
tmp_dir
,
"test-feature-extractor"
),
push_to_hub
=
True
,
use_auth_token
=
self
.
_token
)
new_feature_extractor
=
Wav2Vec2FeatureExtractor
.
from_pretrained
(
f
"
{
USER
}
/test-feature-extractor"
)
for
k
,
v
in
feature_extractor
.
__dict__
.
items
():
self
.
assertEqual
(
v
,
getattr
(
new_feature_extractor
,
k
))
def
test_push_to_hub_in_organization
(
self
):
feature_extractor
=
Wav2Vec2FeatureExtractor
.
from_pretrained
(
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR
)
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
feature_extractor
.
save_pretrained
(
os
.
path
.
join
(
tmp_dir
,
"test-feature-extractor-org"
),
push_to_hub
=
True
,
use_auth_token
=
self
.
_token
,
organization
=
"valid_org"
,
)
new_feature_extractor
=
Wav2Vec2FeatureExtractor
.
from_pretrained
(
"valid_org/test-feature-extractor-org"
)
for
k
,
v
in
feature_extractor
.
__dict__
.
items
():
self
.
assertEqual
(
v
,
getattr
(
new_feature_extractor
,
k
))
def
test_push_to_hub_dynamic_feature_extractor
(
self
):
def
test_push_to_hub_dynamic_feature_extractor
(
self
):
CustomFeatureExtractor
.
register_for_auto_class
()
CustomFeatureExtractor
.
register_for_auto_class
()
feature_extractor
=
CustomFeatureExtractor
.
from_pretrained
(
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR
)
feature_extractor
=
CustomFeatureExtractor
.
from_pretrained
(
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment