chenpangpang/transformers · Commit 2d02f7b2 (unverified)

Authored Feb 15, 2022 by Sylvain Gugger; committed by GitHub on Feb 15, 2022
Parent: bee361c6

Add push_to_hub method to processors (#15668)

* Add push_to_hub method to processors
* Fix test
* The other one too!
Changes: 2 changed files, with 73 additions and 4 deletions (+73 −4)

* src/transformers/processing_utils.py  +33 −2
* tests/test_processor_auto.py          +40 −2
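The headline change is that `ProcessorMixin.save_pretrained` gains a `push_to_hub` flag. Below is a minimal usage sketch modeled on the new tests in `tests/test_processor_auto.py`; the checkpoint path, repository name, and token are placeholders, not values from the commit.

```python
# Minimal sketch of the API added by this commit, modeled on the new tests.
# The checkpoint path, repo name, and token below are placeholders.
from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("path/to/local/processor")

# Saving with push_to_hub=True now also uploads the processor files to the Hub.
processor.save_pretrained(
    "test-processor",
    push_to_hub=True,
    use_auth_token="hf_xxx",  # placeholder token
)
```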
src/transformers/processing_utils.py
```diff
@@ -21,9 +21,13 @@ import os
 from pathlib import Path

 from .dynamic_module_utils import custom_object_save
+from .file_utils import PushToHubMixin, copy_func
 from .tokenization_utils_base import PreTrainedTokenizerBase
 from .utils import logging


 logger = logging.get_logger(__name__)

 # Dynamically import the Transformers module to grab the attribute classes of the processor form their names.
 spec = importlib.util.spec_from_file_location(
     "transformers", Path(__file__).parent / "__init__.py", submodule_search_locations=[Path(__file__).parent]
 )
```
```diff
@@ -37,7 +41,7 @@ AUTO_TO_BASE_CLASS_MAPPING = {
 }


-class ProcessorMixin:
+class ProcessorMixin(PushToHubMixin):
     """
     This is a mixin used to provide saving/loading functionality for all processor classes.
     """
```
```diff
@@ -88,7 +92,7 @@ class ProcessorMixin:
         attributes_repr = "\n".join(attributes_repr)
         return f"{self.__class__.__name__}:\n{attributes_repr}"

-    def save_pretrained(self, save_directory):
+    def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
         """
         Saves the attributes of this processor (feature extractor, tokenizer...) in the specified directory so that it
         can be reloaded using the [`~ProcessorMixin.from_pretrained`] method.
```
```diff
@@ -105,7 +109,24 @@ class ProcessorMixin:
             save_directory (`str` or `os.PathLike`):
                 Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
                 be created if it does not exist).
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your processor to the Hugging Face model hub after saving it.
+
+                <Tip warning={true}>
+
+                Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
+                which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
+                folder. Pass along `temp_dir=True` to use a temporary directory instead.
+
+                </Tip>
+
+            kwargs:
+                Additional key word arguments passed along to the [`~file_utils.PushToHubMixin.push_to_hub`] method.
         """
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            repo = self._create_or_get_repo(save_directory, **kwargs)
+
         os.makedirs(save_directory, exist_ok=True)
         # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
         # loaded from the Hub.
```
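The new docstring also covers the two keyword arguments handled by the added code: `commit_message`, which `save_pretrained` pops before forwarding the remaining kwargs, and `temp_dir=True`, which per the Tip pushes from a temporary directory instead of requiring `save_directory` to be a local clone of the target repo. A hedged sketch; the paths, repo name, and message are placeholders.

```python
# Hedged sketch of the kwargs documented above: commit_message is consumed by
# save_pretrained itself, and temp_dir=True (per the Tip) pushes from a temporary
# directory so "my-processor" does not have to be a clone of the target repo.
from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("path/to/local/processor")  # placeholder path
processor.save_pretrained(
    "my-processor",
    push_to_hub=True,
    commit_message="Upload processor",
    temp_dir=True,
)
```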
```diff
@@ -129,6 +150,10 @@
             if isinstance(attribute, PreTrainedTokenizerBase):
                 del attribute.init_kwargs["auto_map"]

+        if push_to_hub:
+            url = self._push_to_hub(repo, commit_message=commit_message)
+            logger.info(f"Processor pushed to the hub in this commit: {url}")
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
         r"""
```
```diff
@@ -205,3 +230,9 @@
             args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
         return args
+
+
+ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
+ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(
+    object="processor", object_class="AutoProcessor", object_files="processor files"
+)
```
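Since `ProcessorMixin` now inherits from `PushToHubMixin`, and the lines above give the copied `push_to_hub` method a processor-specific docstring, a processor can also be pushed without going through `save_pretrained`. A rough sketch under the `file_utils.PushToHubMixin` API of this era; the keyword names follow the tests below, and the repo names and token are placeholders.

```python
# Rough sketch: pushing a processor directly via the inherited PushToHubMixin.push_to_hub.
# Keyword names (organization, use_auth_token) follow the era's PushToHubMixin and the
# tests below; repo names, path, and the token are placeholders.
from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("path/to/local/processor")  # placeholder path

processor.push_to_hub("test-processor", use_auth_token="hf_xxx")

# Pushing under an organization, mirroring test_push_to_hub_in_organization:
processor.push_to_hub("test-processor-org", organization="valid_org", use_auth_token="hf_xxx")
```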
tests/test_processor_auto.py
```diff
@@ -41,7 +41,7 @@ SAMPLE_PROCESSOR_CONFIG = os.path.join(
 )
 SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/vocab.json")
-SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
+SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")


 class AutoFeatureExtractorTest(unittest.TestCase):
```
```diff
@@ -165,17 +165,55 @@ class ProcessorPushToHubTester(unittest.TestCase):
     @classmethod
     def tearDownClass(cls):
         try:
             delete_repo(token=cls._token, name="test-processor")
         except HTTPError:
             pass

         try:
             delete_repo(token=cls._token, name="test-processor-org", organization="valid_org")
         except HTTPError:
             pass

         try:
             delete_repo(token=cls._token, name="test-dynamic-processor")
         except HTTPError:
             pass

+    def test_push_to_hub(self):
+        processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            processor.save_pretrained(
+                os.path.join(tmp_dir, "test-processor"), push_to_hub=True, use_auth_token=self._token
+            )
+
+            new_processor = Wav2Vec2Processor.from_pretrained(f"{USER}/test-processor")
+            for k, v in processor.feature_extractor.__dict__.items():
+                self.assertEqual(v, getattr(new_processor.feature_extractor, k))
+            self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab())
+
+    def test_push_to_hub_in_organization(self):
+        processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            processor.save_pretrained(
+                os.path.join(tmp_dir, "test-processor-org"),
+                push_to_hub=True,
+                use_auth_token=self._token,
+                organization="valid_org",
+            )
+
+            new_processor = Wav2Vec2Processor.from_pretrained("valid_org/test-processor-org")
+            for k, v in processor.feature_extractor.__dict__.items():
+                self.assertEqual(v, getattr(new_processor.feature_extractor, k))
+            self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab())
+
     def test_push_to_hub_dynamic_processor(self):
         CustomFeatureExtractor.register_for_auto_class()
         CustomTokenizer.register_for_auto_class()
         CustomProcessor.register_for_auto_class()

-        feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
+        feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)

         with tempfile.TemporaryDirectory() as tmp_dir:
             vocab_file = os.path.join(tmp_dir, "vocab.txt")
```
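Both new tests assert the same round trip: every attribute of the reloaded feature extractor matches the original, and the tokenizer vocab is unchanged. The same check can be expressed locally without Hub access; a small sketch, assuming only a processor object exposing `feature_extractor` and `tokenizer` (the helper name is illustrative, not part of the commit).

```python
import tempfile


def check_processor_round_trip(processor):
    """Mirror the new tests' assertions locally: save, reload, and compare the
    feature extractor attributes and tokenizer vocab (no Hub access needed)."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        processor.save_pretrained(tmp_dir)
        reloaded = type(processor).from_pretrained(tmp_dir)

        assert reloaded.tokenizer.get_vocab() == processor.tokenizer.get_vocab()
        for k, v in processor.feature_extractor.__dict__.items():
            assert getattr(reloaded.feature_extractor, k) == v
```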