Unverified Commit 2d02f7b2 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add push_to_hub method to processors (#15668)

* Add push_to_hub method to processors

* Fix test

* The other one too!
parent bee361c6
......@@ -21,9 +21,13 @@ import os
from pathlib import Path
from .dynamic_module_utils import custom_object_save
from .file_utils import PushToHubMixin, copy_func
from .tokenization_utils_base import PreTrainedTokenizerBase
from .utils import logging
logger = logging.get_logger(__name__)
# Dynamically import the Transformers module to grab the attribute classes of the processor form their names.
spec = importlib.util.spec_from_file_location(
"transformers", Path(__file__).parent / "__init__.py", submodule_search_locations=[Path(__file__).parent]
......@@ -37,7 +41,7 @@ AUTO_TO_BASE_CLASS_MAPPING = {
}
class ProcessorMixin:
class ProcessorMixin(PushToHubMixin):
"""
This is a mixin used to provide saving/loading functionality for all processor classes.
"""
......@@ -88,7 +92,7 @@ class ProcessorMixin:
attributes_repr = "\n".join(attributes_repr)
return f"{self.__class__.__name__}:\n{attributes_repr}"
def save_pretrained(self, save_directory):
def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
"""
Saves the attributes of this processor (feature extractor, tokenizer...) in the specified directory so that it
can be reloaded using the [`~ProcessorMixin.from_pretrained`] method.
......@@ -105,7 +109,24 @@ class ProcessorMixin:
save_directory (`str` or `os.PathLike`):
Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
be created if it does not exist).
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your processor to the Hugging Face model hub after saving it.
<Tip warning={true}>
Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
folder. Pass along `temp_dir=True` to use a temporary directory instead.
</Tip>
kwargs:
Additional key word arguments passed along to the [`~file_utils.PushToHubMixin.push_to_hub`] method.
"""
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo = self._create_or_get_repo(save_directory, **kwargs)
os.makedirs(save_directory, exist_ok=True)
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
# loaded from the Hub.
......@@ -129,6 +150,10 @@ class ProcessorMixin:
if isinstance(attribute, PreTrainedTokenizerBase):
del attribute.init_kwargs["auto_map"]
if push_to_hub:
url = self._push_to_hub(repo, commit_message=commit_message)
logger.info(f"Processor pushed to the hub in this commit: {url}")
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r"""
......@@ -205,3 +230,9 @@ class ProcessorMixin:
args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
return args
ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(
object="processor", object_class="AutoProcessor", object_files="processor files"
)
......@@ -41,7 +41,7 @@ SAMPLE_PROCESSOR_CONFIG = os.path.join(
)
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/vocab.json")
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
class AutoFeatureExtractorTest(unittest.TestCase):
......@@ -165,17 +165,55 @@ class ProcessorPushToHubTester(unittest.TestCase):
@classmethod
def tearDownClass(cls):
try:
delete_repo(token=cls._token, name="test-processor")
except HTTPError:
pass
try:
delete_repo(token=cls._token, name="test-processor-org", organization="valid_org")
except HTTPError:
pass
try:
delete_repo(token=cls._token, name="test-dynamic-processor")
except HTTPError:
pass
def test_push_to_hub(self):
processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
with tempfile.TemporaryDirectory() as tmp_dir:
processor.save_pretrained(
os.path.join(tmp_dir, "test-processor"), push_to_hub=True, use_auth_token=self._token
)
new_processor = Wav2Vec2Processor.from_pretrained(f"{USER}/test-processor")
for k, v in processor.feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_processor.feature_extractor, k))
self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab())
def test_push_to_hub_in_organization(self):
processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
with tempfile.TemporaryDirectory() as tmp_dir:
processor.save_pretrained(
os.path.join(tmp_dir, "test-processor-org"),
push_to_hub=True,
use_auth_token=self._token,
organization="valid_org",
)
new_processor = Wav2Vec2Processor.from_pretrained("valid_org/test-processor-org")
for k, v in processor.feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_processor.feature_extractor, k))
self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab())
def test_push_to_hub_dynamic_processor(self):
CustomFeatureExtractor.register_for_auto_class()
CustomTokenizer.register_for_auto_class()
CustomProcessor.register_for_auto_class()
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment