Unverified commit bf2e0cf7, authored by Sylvain Gugger, committed by GitHub

Trainer push to hub (#11328)



* Initial support for upload to hub

* push -> upload

* Fixes + examples

* Fix torchhub test

* Torchhub test I hate you

* push_model_to_hub -> push_to_hub

* Apply mixin to other pretrained models

* Remove ABC inheritance

* Add tests

* Typo

* Run tests

* Install git-lfs

* Change approach

* Add push_to_hub to all

* Staging test suite

* Typo

* Maybe like this?

* More deps

* Cache

* Adapt name

* Quality

* MOAR tests

* Put it in testing_utils

* Docs + torchhub last hope

* Styling

* Wrong method

* Typos

* Update src/transformers/file_utils.py
Co-authored-by: Julien Chaumond <julien@huggingface.co>

* Address review comments

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Julien Chaumond <julien@huggingface.co>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 7bc86bea
@@ -50,6 +50,11 @@ DUMMY_UNKWOWN_IDENTIFIER = "julien-c/dummy-unknown"
DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"
# Used to test Auto{Config, Model, Tokenizer} model_type detection.
# Used to test the hub
USER = "__DUMMY_TRANSFORMERS_USER__"
PASS = "__DUMMY_TRANSFORMERS_PASS__"
ENDPOINT_STAGING = "https://moon-staging.huggingface.co"
def parse_flag_from_env(key, default=False):
try:
@@ -84,6 +89,7 @@ _run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_pt_tf_cross_tests = parse_flag_from_env("RUN_PT_TF_CROSS_TESTS", default=False)
_run_pt_flax_cross_tests = parse_flag_from_env("RUN_PT_FLAX_CROSS_TESTS", default=False)
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=False)
_run_git_lfs_tests = parse_flag_from_env("RUN_GIT_LFS_TESTS", default=False)
_tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None)
@@ -146,6 +152,23 @@ def is_pipeline_test(test_case):
return pytest.mark.is_pipeline_test()(test_case)
def is_staging_test(test_case):
"""
Decorator marking a test as a staging test.
Those tests will run using the staging environment of huggingface.co instead of the real model hub.
"""
if not _run_staging:
return unittest.skip("test is staging test")(test_case)
else:
try:
import pytest # We don't need a hard dependency on pytest in the main library
except ImportError:
return test_case
else:
return pytest.mark.is_staging_test()(test_case)
def slow(test_case):
"""
Decorator marking a test as slow.
...
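A short sketch (not part of this commit) of how the new staging constants and decorator are meant to be used in a test module; the class and test names are illustrative:

import unittest

from transformers.testing_utils import ENDPOINT_STAGING, is_staging_test

@is_staging_test
class ExampleStagingTest(unittest.TestCase):
    # Skipped unless HUGGINGFACE_CO_STAGING is set to a truthy value; when it is,
    # the test runs against the staging endpoint instead of the production hub.
    def test_staging_endpoint(self):
        self.assertIn("moon-staging", ENDPOINT_STAGING)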
@@ -34,6 +34,7 @@ import requests
from .file_utils import (
ExplicitEnum,
PaddingStrategy,
PushToHubMixin,
TensorType,
_is_jax,
_is_numpy,
@@ -1415,7 +1416,7 @@ INIT_TOKENIZER_DOCSTRING = r"""
@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
"""
Base class for :class:`~transformers.PreTrainedTokenizer` and :class:`~transformers.PreTrainedTokenizerFast`.
@@ -1850,6 +1851,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
save_directory: Union[str, os.PathLike],
legacy_format: Optional[bool] = None,
filename_prefix: Optional[str] = None,
push_to_hub: bool = False,
**kwargs,
) -> Tuple[str]:
"""
Save the full tokenizer state.
@@ -1925,13 +1928,21 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
file_names = (tokenizer_config_file, special_tokens_map_file)
save_files = self._save_pretrained(
save_directory=save_directory,
file_names=file_names,
legacy_format=legacy_format,
filename_prefix=filename_prefix,
)
if push_to_hub:
# Annoyingly, the return contains files that don't exist.
existing_files = [f for f in save_files if os.path.isfile(f)]
url = self._push_to_hub(save_files=existing_files, **kwargs)
logger.info(f"Tokenizer pushed to the hub in this commit: {url}")
return save_files
def _save_pretrained(
self,
save_directory: Union[str, os.PathLike],
...
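A minimal usage sketch of the new push_to_hub flag on save_pretrained, mirroring the tokenizer tests added later in this commit; the repository name is an illustrative placeholder and the remaining keyword arguments are forwarded to _push_to_hub:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# Saves the tokenizer files locally, then uploads the files that actually exist
# to the hub in a single commit (the commit URL is logged).
saved_files = tokenizer.save_pretrained(
    "my-tokenizer",
    push_to_hub=True,
    repo_name="my-tokenizer",
    use_auth_token=True,
)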
@@ -23,6 +23,7 @@ import os
import re
import shutil
import sys
import tempfile
import time
import warnings
from logging import StreamHandler
@@ -62,6 +63,7 @@ from .dependency_versions_check import dep_version_check
from .file_utils import (
CONFIG_NAME,
WEIGHTS_NAME,
PushToHubMixin,
is_apex_available,
is_datasets_available,
is_in_notebook,
@@ -2274,6 +2276,71 @@ class Trainer:
else:
return 0
def push_to_hub(
self,
save_directory: Optional[str] = None,
repo_name: Optional[str] = None,
repo_url: Optional[str] = None,
commit_message: Optional[str] = "add model",
organization: Optional[str] = None,
private: bool = None,
use_auth_token: Optional[Union[bool, str]] = None,
):
"""
Upload `self.model` to the 🤗 model hub.
Parameters:
save_directory (:obj:`str` or :obj:`os.PathLike`):
Folder containing the model weights and config. Will default to :obj:`self.args.output_dir`.
repo_name (:obj:`str`, `optional`):
Repository name for your model or tokenizer in the hub. If not specified, the repository name will be
the stem of :obj:`save_directory`.
repo_url (:obj:`str`, `optional`):
Specify this in case you want to push to an existing repository in the hub. If unspecified, a new
repository will be created in your namespace (unless you specify an :obj:`organization`) with
:obj:`repo_name`.
commit_message (:obj:`str`, `optional`, defaults to :obj:`"add model"`):
Message to commit while pushing.
organization (:obj:`str`, `optional`):
Organization in which you want to push your model or tokenizer (you must be a member of this
organization).
private (:obj:`bool`, `optional`):
Whether or not the repository created should be private (requires a paying subscription).
use_auth_token (:obj:`bool` or :obj:`str`, `optional`):
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). Will default to
:obj:`True` if :obj:`repo_url` is not specified.
Returns:
The url of the commit of your model in the given repository.
"""
if not self.is_world_process_zero():
return
if not isinstance(unwrap_model(self.model), PushToHubMixin):
raise ValueError(
"The `upload_model_to_hub` method only works for models that inherit from `PushToHubMixin` models."
)
if save_directory is None:
save_directory = self.args.output_dir
# To avoid pushing all checkpoints, we just copy all the files in save_directory in a tmp dir.
with tempfile.TemporaryDirectory() as tmp_dir:
for f in os.listdir(save_directory):
fname = os.path.join(save_directory, f)
if os.path.isfile(fname):
shutil.copy(fname, os.path.join(tmp_dir, f))
return unwrap_model(self.model)._push_to_hub(
save_directory=tmp_dir,
repo_name=repo_name,
repo_url=repo_url,
commit_message=commit_message,
organization=organization,
private=private,
use_auth_token=use_auth_token,
)
#
# Deprecated code
#
...
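A minimal usage sketch of the new Trainer.push_to_hub API, mirroring the integration tests added further down; the model, dataset, and repository name are illustrative placeholders:

from transformers import Trainer, TrainingArguments

args = TrainingArguments(output_dir="finetuned-model")
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)  # model/train_dataset assumed defined
trainer.train()
trainer.save_model()  # writes the final weights and config to args.output_dir
# Copies the plain files from output_dir (checkpoint subfolders are skipped) into a
# temporary directory and uploads them; returns the commit URL, or None on non-zero ranks.
url = trainer.push_to_hub(repo_name="finetuned-model", use_auth_token=True)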
@@ -295,11 +295,15 @@ class TrainingArguments:
When using distributed training, the value of the flag :obj:`find_unused_parameters` passed to
:obj:`DistributedDataParallel`. Will default to :obj:`False` if gradient checkpointing is used, :obj:`True`
otherwise.
dataloader_pin_memory (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether you want to pin memory in data loaders or not. Will default to :obj:`True`.
skip_memory_metrics (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to skip adding of memory profiler reports to metrics. Defaults to :obj:`False`.
push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to upload the trained model to the hub after training. This argument is not directly used by
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
details.
""" """
output_dir: str = field( output_dir: str = field(
...@@ -527,6 +531,9 @@ class TrainingArguments: ...@@ -527,6 +531,9 @@ class TrainingArguments:
use_legacy_prediction_loop: bool = field( use_legacy_prediction_loop: bool = field(
default=False, metadata={"help": "Whether or not to use the legacy prediction_loop in the Trainer."} default=False, metadata={"help": "Whether or not to use the legacy prediction_loop in the Trainer."}
) )
push_to_hub: bool = field(
default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
)
_n_gpu: int = field(init=False, repr=False, default=-1)
mp_parameters: str = field(
default="",
...
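As the new docstring states, the push_to_hub flag is meant to be read by training scripts rather than by the Trainer itself; a sketch of the intended pattern, assuming trainer and training_args are already constructed as in the example scripts:

trainer.train()
trainer.save_model()
if training_args.push_to_hub:
    trainer.push_to_hub()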
@@ -38,6 +38,7 @@ def pytest_configure(config):
config.addinivalue_line(
"markers", "is_pt_flax_cross_test: mark test to run only when PT and FLAX interactions are tested"
)
config.addinivalue_line("markers", "is_staging_test: mark test to run only in the staging environment")
def pytest_addoption(parser):
...
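The registered marker can then be used for test selection; a sketch of the programmatic equivalent of pytest -m is_staging_test (the HUGGINGFACE_CO_STAGING environment variable must also be set, otherwise the decorator skips the tests):

import pytest

# Collect and run only tests carrying the is_staging_test marker registered above.
pytest.main(["-m", "is_staging_test", "tests/"])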
@@ -17,6 +17,12 @@
import json
import os
import tempfile
import unittest
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import BertConfig
from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test
class ConfigTester(object):
@@ -81,3 +87,54 @@ class ConfigTester(object):
self.create_and_test_config_from_and_save_pretrained()
self.create_and_test_config_with_num_labels()
self.check_config_can_be_init_without_params()
@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls._api = HfApi(endpoint=ENDPOINT_STAGING)
cls._token = cls._api.login(username=USER, password=PASS)
@classmethod
def tearDownClass(cls):
try:
cls._api.delete_repo(token=cls._token, name="test-model")
except HTTPError:
pass
try:
cls._api.delete_repo(token=cls._token, name="test-model-org", organization="valid_org")
except HTTPError:
pass
def test_push_to_hub(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(tmp_dir, push_to_hub=True, repo_name="test-model", use_auth_token=self._token)
new_config = BertConfig.from_pretrained(f"{USER}/test-model")
for k, v in config.__dict__.items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
def test_push_to_hub_in_organization(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(
tmp_dir,
push_to_hub=True,
repo_name="test-model-org",
use_auth_token=self._token,
organization="valid_org",
)
new_config = BertConfig.from_pretrained("valid_org/test-model-org")
for k, v in config.__dict__.items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
@@ -22,13 +22,9 @@ import unittest
from requests.exceptions import HTTPError
from transformers.hf_api import HfApi, HfFolder, ModelInfo, RepoObj
from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test, require_git_lfs
USER = "__DUMMY_TRANSFORMERS_USER__"
PASS = "__DUMMY_TRANSFORMERS_PASS__"
ENDPOINT_STAGING = "https://moon-staging.huggingface.co"
ENDPOINT_STAGING_BASIC_AUTH = f"https://{USER}:{PASS}@moon-staging.huggingface.co"
REPO_NAME = f"my-model-{int(time.time())}"
@@ -106,6 +102,7 @@ class HfFolderTest(unittest.TestCase):
@require_git_lfs
@is_staging_test
class HfLargefilesTest(HfApiCommonTest):
@classmethod
def setUpClass(cls):
...
@@ -22,10 +22,22 @@ import tempfile
import unittest
from typing import List, Tuple
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import is_torch_available, logging
from transformers.file_utils import WEIGHTS_NAME
from transformers.models.auto import get_values
from transformers.testing_utils import (
ENDPOINT_STAGING,
PASS,
USER,
CaptureLogger,
is_staging_test,
require_torch,
require_torch_multi_gpu,
slow,
torch_device,
)
if is_torch_available():
@@ -1300,3 +1312,54 @@ class ModelUtilsTest(unittest.TestCase):
with CaptureLogger(logger) as cl:
BertModel.from_pretrained(TINY_T5)
self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)
@require_torch
@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls._api = HfApi(endpoint=ENDPOINT_STAGING)
cls._token = cls._api.login(username=USER, password=PASS)
@classmethod
def tearDownClass(cls):
try:
cls._api.delete_repo(token=cls._token, name="test-model")
except HTTPError:
pass
try:
cls._api.delete_repo(token=cls._token, name="test-model-org", organization="valid_org")
except HTTPError:
pass
def test_push_to_hub(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = BertModel(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, push_to_hub=True, repo_name="test-model", use_auth_token=self._token)
new_model = BertModel.from_pretrained(f"{USER}/test-model")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
def test_push_to_hub_in_organization(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = BertModel(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(
tmp_dir,
push_to_hub=True,
repo_name="test-model-org",
use_auth_token=self._token,
organization="valid_org",
)
new_model = BertModel.from_pretrained("valid_org/test-model-org")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
@@ -24,11 +24,17 @@ import unittest
from importlib import import_module
from typing import List, Tuple
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import is_tf_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
ENDPOINT_STAGING,
PASS,
USER,
_tf_gpu_memory_limit,
is_pt_tf_cross_test,
is_staging_test,
require_onnx,
require_tf,
slow,
@@ -50,6 +56,8 @@ if is_tf_available():
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
BertConfig,
TFBertModel,
TFSharedEmbeddings,
tf_top_k_top_p_filtering,
)
@@ -1326,3 +1334,62 @@ class UtilsFunctionsTest(unittest.TestCase):
tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12)
tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx)
@require_tf
@is_staging_test
class TFModelPushToHubTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls._api = HfApi(endpoint=ENDPOINT_STAGING)
cls._token = cls._api.login(username=USER, password=PASS)
@classmethod
def tearDownClass(cls):
try:
cls._api.delete_repo(token=cls._token, name="test-model")
except HTTPError:
pass
try:
cls._api.delete_repo(token=cls._token, name="test-model-org", organization="valid_org")
except HTTPError:
pass
def test_push_to_hub(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = TFBertModel(config)
# Make sure model is properly initialized
_ = model(model.dummy_inputs)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, push_to_hub=True, repo_name="test-model", use_auth_token=self._token)
new_model = TFBertModel.from_pretrained(f"{USER}/test-model")
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
models_equal = False
self.assertTrue(models_equal)
def test_push_to_hub_in_organization(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = TFBertModel(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(
tmp_dir,
push_to_hub=True,
repo_name="test-model-org",
use_auth_token=self._token,
organization="valid_org",
)
new_model = TFBertModel.from_pretrained("valid_org/test-model-org")
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0:
models_equal = False
self.assertTrue(models_equal)
@@ -20,11 +20,15 @@ import pickle
import re
import shutil
import tempfile
import unittest
from collections import OrderedDict
from itertools import takewhile
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import (
BertTokenizer,
PreTrainedTokenizer,
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
@@ -32,8 +36,12 @@ from transformers import (
is_torch_available,
)
from transformers.testing_utils import (
ENDPOINT_STAGING,
PASS,
USER,
get_tests_dir,
is_pt_tf_cross_test,
is_staging_test,
require_tf,
require_tokenizers,
require_torch,
@@ -2863,3 +2871,53 @@ class TokenizerTesterMixin:
)
for key in python_output:
self.assertEqual(python_output[key], rust_output[key])
@is_staging_test
class TokenzierPushToHubTester(unittest.TestCase):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]
@classmethod
def setUpClass(cls):
cls._api = HfApi(endpoint=ENDPOINT_STAGING)
cls._token = cls._api.login(username=USER, password=PASS)
@classmethod
def tearDownClass(cls):
try:
cls._api.delete_repo(token=cls._token, name="test-model")
except HTTPError:
pass
try:
cls._api.delete_repo(token=cls._token, name="test-model-org", organization="valid_org")
except HTTPError:
pass
def test_push_to_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
tokenizer.save_pretrained(tmp_dir, push_to_hub=True, repo_name="test-model", use_auth_token=self._token)
new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-model")
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
def test_push_to_hub_in_organization(self):
with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
tokenizer.save_pretrained(
tmp_dir,
push_to_hub=True,
repo_name="test-model-org",
use_auth_token=self._token,
organization="valid_org",
)
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-model-org")
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
@@ -16,16 +16,23 @@
import dataclasses
import gc
import os
import re
import tempfile
import unittest
import numpy as np
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import AutoTokenizer, IntervalStrategy, PretrainedConfig, TrainingArguments, is_torch_available
from transformers.file_utils import WEIGHTS_NAME
from transformers.testing_utils import (
ENDPOINT_STAGING,
PASS,
USER,
TestCasePlus,
get_tests_dir,
is_staging_test,
require_datasets,
require_optuna,
require_ray,
@@ -1081,6 +1088,60 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params)
@require_torch
@is_staging_test
class TrainerIntegrationWithHubTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls._api = HfApi(endpoint=ENDPOINT_STAGING)
cls._token = cls._api.login(username=USER, password=PASS)
@classmethod
def tearDownClass(cls):
try:
cls._api.delete_repo(token=cls._token, name="test-model")
except HTTPError:
pass
try:
cls._api.delete_repo(token=cls._token, name="test-model-org", organization="valid_org")
except HTTPError:
pass
def test_push_to_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(output_dir=tmp_dir)
trainer.save_model()
url = trainer.push_to_hub(repo_name="test-model", use_auth_token=self._token)
# Extract repo_name from the url
re_search = re.search(ENDPOINT_STAGING + r"/([^/]+/[^/]+)/", url)
self.assertTrue(re_search is not None)
repo_name = re_search.groups()[0]
self.assertEqual(repo_name, f"{USER}/test-model")
model = RegressionPreTrainedModel.from_pretrained(repo_name)
self.assertEqual(model.a.item(), trainer.model.a.item())
self.assertEqual(model.b.item(), trainer.model.b.item())
def test_push_to_hub_in_organization(self):
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(output_dir=tmp_dir)
trainer.save_model()
url = trainer.push_to_hub(repo_name="test-model-org", organization="valid_org", use_auth_token=self._token)
# Extract repo_name from the url
re_search = re.search(ENDPOINT_STAGING + r"/([^/]+/[^/]+)/", url)
self.assertTrue(re_search is not None)
repo_name = re_search.groups()[0]
self.assertEqual(repo_name, "valid_org/test-model-org")
model = RegressionPreTrainedModel.from_pretrained("valid_org/test-model-org")
self.assertEqual(model.a.item(), trainer.model.a.item())
self.assertEqual(model.b.item(), trainer.model.b.item())
@require_torch
@require_optuna
class TrainerHyperParameterOptunaIntegrationTest(unittest.TestCase):
...