"examples/vscode:/vscode.git/clone" did not exist on "e84adbed408ddf77e30ac5bd437ecc1f7f3f708a"
Unverified Commit 2d02f7b2 authored by Sylvain Gugger, committed by GitHub

Add push_to_hub method to processors (#15668)

* Add push_to_hub method to processors

* Fix test

* The other one too!
parent bee361c6
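In practice, this commit lets a processor be saved and pushed to the Hub in a single call. A minimal sketch of the new workflow (the checkpoint name, directory, and token handling are illustrative, not taken from the diff):

    # Illustrative usage of the API added by this commit; names are hypothetical.
    from transformers import Wav2Vec2Processor

    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

    # New in this commit: push_to_hub=True uploads the saved processor files
    # (feature extractor config + tokenizer files) after writing them locally.
    processor.save_pretrained(
        "test-processor",     # local save directory, also used as the repo name
        push_to_hub=True,
        use_auth_token=True,  # assumes a prior `huggingface-cli login`
    )

The new tests further down exercise this same pattern.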
@@ -21,9 +21,13 @@ import os
 from pathlib import Path

 from .dynamic_module_utils import custom_object_save
+from .file_utils import PushToHubMixin, copy_func
 from .tokenization_utils_base import PreTrainedTokenizerBase
+from .utils import logging


+logger = logging.get_logger(__name__)
+
 # Dynamically import the Transformers module to grab the attribute classes of the processor from their names.
 spec = importlib.util.spec_from_file_location(
     "transformers", Path(__file__).parent / "__init__.py", submodule_search_locations=[Path(__file__).parent]
@@ -37,7 +41,7 @@ AUTO_TO_BASE_CLASS_MAPPING = {
 }


-class ProcessorMixin:
+class ProcessorMixin(PushToHubMixin):
     """
     This is a mixin used to provide saving/loading functionality for all processor classes.
     """
@@ -88,7 +92,7 @@ class ProcessorMixin:
         attributes_repr = "\n".join(attributes_repr)
         return f"{self.__class__.__name__}:\n{attributes_repr}"

-    def save_pretrained(self, save_directory):
+    def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
         """
         Saves the attributes of this processor (feature extractor, tokenizer...) in the specified directory so that it
         can be reloaded using the [`~ProcessorMixin.from_pretrained`] method.
@@ -105,7 +109,24 @@ class ProcessorMixin:
             save_directory (`str` or `os.PathLike`):
                 Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
                 be created if it does not exist).
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your processor to the Hugging Face model hub after saving it.
+
+                <Tip warning={true}>
+
+                Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
+                which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
+                folder. Pass along `temp_dir=True` to use a temporary directory instead.
+
+                </Tip>
+
+            kwargs:
+                Additional keyword arguments passed along to the [`~file_utils.PushToHubMixin.push_to_hub`] method.
         """
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            repo = self._create_or_get_repo(save_directory, **kwargs)
+
         os.makedirs(save_directory, exist_ok=True)
         # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
         # loaded from the Hub.
@@ -129,6 +150,10 @@ class ProcessorMixin:
             if isinstance(attribute, PreTrainedTokenizerBase):
                 del attribute.init_kwargs["auto_map"]

+        if push_to_hub:
+            url = self._push_to_hub(repo, commit_message=commit_message)
+            logger.info(f"Processor pushed to the hub in this commit: {url}")
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
         r"""
@@ -205,3 +230,9 @@ class ProcessorMixin:
         args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))

         return args
+
+
+ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
+ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(
+    object="processor", object_class="AutoProcessor", object_files="processor files"
+)
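The final hunk above re-binds `push_to_hub` through `copy_func` so the templated docstring inherited from `PushToHubMixin` can be filled in with processor-specific wording without mutating the method every other class shares. A standalone sketch of that pattern (a re-creation for illustration, not the transformers source):

    # Sketch of the docstring-templating trick: copy the function first, then
    # format its __doc__, so the shared template on the mixin stays intact.
    import functools
    import types

    def copy_func(f):
        g = types.FunctionType(
            f.__code__, f.__globals__, name=f.__name__,
            argdefs=f.__defaults__, closure=f.__closure__,
        )
        g = functools.update_wrapper(g, f)  # copies __doc__, __dict__, ...
        g.__kwdefaults__ = f.__kwdefaults__
        return g

    class HubMixin:
        def push_to_hub(self):
            """Upload the {object} files to the Hub."""

    class Processor(HubMixin):
        pass

    Processor.push_to_hub = copy_func(Processor.push_to_hub)
    Processor.push_to_hub.__doc__ = Processor.push_to_hub.__doc__.format(object="processor")

    print(HubMixin.push_to_hub.__doc__)   # Upload the {object} files to the Hub.
    print(Processor.push_to_hub.__doc__)  # Upload the processor files to the Hub.

Formatting the docstring in place would rewrite it on the mixin itself; copying first keeps the `{object}`/`{object_class}`/`{object_files}` template reusable for other subclasses.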
@@ -41,7 +41,7 @@ SAMPLE_PROCESSOR_CONFIG = os.path.join(
 )

 SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/vocab.json")
-SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
+SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")


 class AutoFeatureExtractorTest(unittest.TestCase):
@@ -165,17 +165,55 @@ class ProcessorPushToHubTester(unittest.TestCase):
     @classmethod
     def tearDownClass(cls):
+        try:
+            delete_repo(token=cls._token, name="test-processor")
+        except HTTPError:
+            pass
+
+        try:
+            delete_repo(token=cls._token, name="test-processor-org", organization="valid_org")
+        except HTTPError:
+            pass
+
         try:
             delete_repo(token=cls._token, name="test-dynamic-processor")
         except HTTPError:
             pass

+    def test_push_to_hub(self):
+        processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            processor.save_pretrained(
+                os.path.join(tmp_dir, "test-processor"), push_to_hub=True, use_auth_token=self._token
+            )
+
+            new_processor = Wav2Vec2Processor.from_pretrained(f"{USER}/test-processor")
+            for k, v in processor.feature_extractor.__dict__.items():
+                self.assertEqual(v, getattr(new_processor.feature_extractor, k))
+            self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab())
+
+    def test_push_to_hub_in_organization(self):
+        processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            processor.save_pretrained(
+                os.path.join(tmp_dir, "test-processor-org"),
+                push_to_hub=True,
+                use_auth_token=self._token,
+                organization="valid_org",
+            )
+
+            new_processor = Wav2Vec2Processor.from_pretrained("valid_org/test-processor-org")
+            for k, v in processor.feature_extractor.__dict__.items():
+                self.assertEqual(v, getattr(new_processor.feature_extractor, k))
+            self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab())
+
     def test_push_to_hub_dynamic_processor(self):
         CustomFeatureExtractor.register_for_auto_class()
         CustomTokenizer.register_for_auto_class()
         CustomProcessor.register_for_auto_class()

-        feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
+        feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)

         with tempfile.TemporaryDirectory() as tmp_dir:
             vocab_file = os.path.join(tmp_dir, "vocab.txt")
...
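For reference, the dynamic-processor test relies on `register_for_auto_class`, which records an `auto_map` entry in the saved configuration files so the auto classes can locate the custom code in the pushed repo. Loading such a processor back looks roughly like this (the repo id is hypothetical; `trust_remote_code=True` is needed because the class definitions live in the Hub repo rather than in the library):

    # Hypothetical repo id; executes the custom processor code stored in the repo.
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained(
        "my-user/test-dynamic-processor", trust_remote_code=True
    )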