Unverified Commit 558f8543 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Update Transformers to huggingface_hub >= 0.1.0 (#14251)

* Update Transformers to huggingface_hub >= 0.1.0

* Forgot to save...

* Style

* Fix test
parent 519a677e
...@@ -103,8 +103,8 @@ Here is the code to see all available pretrained models on the hub: ...@@ -103,8 +103,8 @@ Here is the code to see all available pretrained models on the hub:
.. code-block:: python .. code-block:: python
from huggingface_hub.hf_api import HfApi from huggingface_hub import list_models
model_list = HfApi().list_models() model_list = list_models()
org = "Helsinki-NLP" org = "Helsinki-NLP"
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)] model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
suffix = [x.split('/')[1] for x in model_ids] suffix = [x.split('/')[1] for x in model_ids]
......
...@@ -14,7 +14,7 @@ import lightning_base ...@@ -14,7 +14,7 @@ import lightning_base
from convert_pl_checkpoint_to_hf import convert_pl_to_hf from convert_pl_checkpoint_to_hf import convert_pl_to_hf
from distillation import distill_main from distillation import distill_main
from finetune import SummarizationModule, main from finetune import SummarizationModule, main
from huggingface_hub.hf_api import HfApi from huggingface_hub import list_models
from parameterized import parameterized from parameterized import parameterized
from run_eval import generate_summaries_or_translations from run_eval import generate_summaries_or_translations
from transformers import AutoConfig, AutoModelForSeq2SeqLM from transformers import AutoConfig, AutoModelForSeq2SeqLM
...@@ -130,7 +130,7 @@ class TestSummarizationDistiller(TestCasePlus): ...@@ -130,7 +130,7 @@ class TestSummarizationDistiller(TestCasePlus):
def test_hub_configs(self): def test_hub_configs(self):
"""I put require_torch_gpu cause I only want this to run with self-scheduled.""" """I put require_torch_gpu cause I only want this to run with self-scheduled."""
model_list = HfApi().list_models() model_list = list_models()
org = "sshleifer" org = "sshleifer"
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)] model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
allowed_to_be_broken = ["sshleifer/blenderbot-3B", "sshleifer/blenderbot-90M"] allowed_to_be_broken = ["sshleifer/blenderbot-3B", "sshleifer/blenderbot-90M"]
......
...@@ -100,7 +100,7 @@ _deps = [ ...@@ -100,7 +100,7 @@ _deps = [
"flax>=0.3.4", "flax>=0.3.4",
"fugashi>=1.0", "fugashi>=1.0",
"GitPython<3.1.19", "GitPython<3.1.19",
"huggingface-hub>=0.0.17", "huggingface-hub>=0.1.0,<1.0",
"importlib_metadata", "importlib_metadata",
"ipadic>=1.0.0,<2.0", "ipadic>=1.0.0,<2.0",
"isort>=5.5.4", "isort>=5.5.4",
......
...@@ -12,14 +12,12 @@ ...@@ -12,14 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import subprocess import subprocess
import sys
from argparse import ArgumentParser from argparse import ArgumentParser
from getpass import getpass from getpass import getpass
from typing import List, Union from typing import List, Union
from huggingface_hub.hf_api import HfApi, HfFolder from huggingface_hub.hf_api import HfFolder, create_repo, list_repos_objs, login, logout, whoami
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from . import BaseTransformersCLICommand from . import BaseTransformersCLICommand
...@@ -142,7 +140,6 @@ def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str: ...@@ -142,7 +140,6 @@ def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
class BaseUserCommand: class BaseUserCommand:
def __init__(self, args): def __init__(self, args):
self.args = args self.args = args
self._api = HfApi()
class LoginCommand(BaseUserCommand): class LoginCommand(BaseUserCommand):
...@@ -166,7 +163,7 @@ class LoginCommand(BaseUserCommand): ...@@ -166,7 +163,7 @@ class LoginCommand(BaseUserCommand):
username = input("Username: ") username = input("Username: ")
password = getpass() password = getpass()
try: try:
token = self._api.login(username, password) token = login(username, password)
except HTTPError as e: except HTTPError as e:
# probably invalid credentials, display error message. # probably invalid credentials, display error message.
print(e) print(e)
...@@ -191,7 +188,7 @@ class WhoamiCommand(BaseUserCommand): ...@@ -191,7 +188,7 @@ class WhoamiCommand(BaseUserCommand):
print("Not logged in") print("Not logged in")
exit() exit()
try: try:
user, orgs = self._api.whoami(token) user, orgs = whoami(token)
print(user) print(user)
if orgs: if orgs:
print(ANSI.bold("orgs: "), ",".join(orgs)) print(ANSI.bold("orgs: "), ",".join(orgs))
...@@ -214,7 +211,7 @@ class LogoutCommand(BaseUserCommand): ...@@ -214,7 +211,7 @@ class LogoutCommand(BaseUserCommand):
print("Not logged in") print("Not logged in")
exit() exit()
HfFolder.delete_token() HfFolder.delete_token()
self._api.logout(token) logout(token)
print("Successfully logged out.") print("Successfully logged out.")
...@@ -222,46 +219,24 @@ class ListObjsCommand(BaseUserCommand): ...@@ -222,46 +219,24 @@ class ListObjsCommand(BaseUserCommand):
def run(self): def run(self):
print( print(
ANSI.red( ANSI.red(
"WARNING! Managing repositories through transformers-cli is deprecated. " "Command removed: it used to be the way to delete an object on S3."
"Please use `huggingface-cli` instead." " We now use a git-based system for storing models and other artifacts."
" Use list-repo-objs instead"
) )
) )
token = HfFolder.get_token() exit(1)
if token is None:
print("Not logged in")
exit(1)
try:
objs = self._api.list_objs(token, organization=self.args.organization)
except HTTPError as e:
print(e)
print(ANSI.red(e.response.text))
exit(1)
if len(objs) == 0:
print("No shared file yet")
exit()
rows = [[obj.filename, obj.LastModified, obj.ETag, obj.Size] for obj in objs]
print(tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))
class DeleteObjCommand(BaseUserCommand): class DeleteObjCommand(BaseUserCommand):
def run(self): def run(self):
print( print(
ANSI.red( ANSI.red(
"WARNING! Managing repositories through transformers-cli is deprecated. " "Command removed: it used to be the way to delete an object on S3."
"Please use `huggingface-cli` instead." " We now use a git-based system for storing models and other artifacts."
" Use delete-repo instead"
) )
) )
token = HfFolder.get_token() exit(1)
if token is None:
print("Not logged in")
exit(1)
try:
self._api.delete_obj(token, filename=self.args.filename, organization=self.args.organization)
except HTTPError as e:
print(e)
print(ANSI.red(e.response.text))
exit(1)
print("Done")
class ListReposObjsCommand(BaseUserCommand): class ListReposObjsCommand(BaseUserCommand):
...@@ -277,7 +252,7 @@ class ListReposObjsCommand(BaseUserCommand): ...@@ -277,7 +252,7 @@ class ListReposObjsCommand(BaseUserCommand):
print("Not logged in") print("Not logged in")
exit(1) exit(1)
try: try:
objs = self._api.list_repos_objs(token, organization=self.args.organization) objs = list_repos_objs(token, organization=self.args.organization)
except HTTPError as e: except HTTPError as e:
print(e) print(e)
print(ANSI.red(e.response.text)) print(ANSI.red(e.response.text))
...@@ -320,7 +295,7 @@ class RepoCreateCommand(BaseUserCommand): ...@@ -320,7 +295,7 @@ class RepoCreateCommand(BaseUserCommand):
) )
print("") print("")
user, _ = self._api.whoami(token) user, _ = whoami(token)
namespace = self.args.organization if self.args.organization is not None else user namespace = self.args.organization if self.args.organization is not None else user
full_name = f"{namespace}/{self.args.name}" full_name = f"{namespace}/{self.args.name}"
print(f"You are about to create {ANSI.bold(full_name)}") print(f"You are about to create {ANSI.bold(full_name)}")
...@@ -331,7 +306,7 @@ class RepoCreateCommand(BaseUserCommand): ...@@ -331,7 +306,7 @@ class RepoCreateCommand(BaseUserCommand):
print("Abort") print("Abort")
exit() exit()
try: try:
url = self._api.create_repo(token, name=self.args.name, organization=self.args.organization) url = create_repo(token, name=self.args.name, organization=self.args.organization)
except HTTPError as e: except HTTPError as e:
print(e) print(e)
print(ANSI.red(e.response.text)) print(ANSI.red(e.response.text))
...@@ -356,73 +331,12 @@ class DeprecatedUploadCommand(BaseUserCommand): ...@@ -356,73 +331,12 @@ class DeprecatedUploadCommand(BaseUserCommand):
class UploadCommand(BaseUserCommand): class UploadCommand(BaseUserCommand):
def walk_dir(self, rel_path):
"""
Recursively list all files in a folder.
"""
entries: List[os.DirEntry] = list(os.scandir(rel_path))
files = [(os.path.join(os.getcwd(), f.path), f.path) for f in entries if f.is_file()] # (filepath, filename)
for f in entries:
if f.is_dir():
files += self.walk_dir(f.path)
return files
def run(self): def run(self):
print( print(
ANSI.red( ANSI.red(
"WARNING! Managing repositories through transformers-cli is deprecated. " "Deprecated: used to be the way to upload a model to S3."
"Please use `huggingface-cli` instead." " We now use a git-based system for storing models and other artifacts."
" Use the `repo create` command instead."
) )
) )
token = HfFolder.get_token() exit(1)
if token is None:
print("Not logged in")
exit(1)
local_path = os.path.abspath(self.args.path)
if os.path.isdir(local_path):
if self.args.filename is not None:
raise ValueError("Cannot specify a filename override when uploading a folder.")
rel_path = os.path.basename(local_path)
files = self.walk_dir(rel_path)
elif os.path.isfile(local_path):
filename = self.args.filename if self.args.filename is not None else os.path.basename(local_path)
files = [(local_path, filename)]
else:
raise ValueError(f"Not a valid file or directory: {local_path}")
if sys.platform == "win32":
files = [(filepath, filename.replace(os.sep, "/")) for filepath, filename in files]
if len(files) > UPLOAD_MAX_FILES:
print(
f"About to upload {ANSI.bold(len(files))} files to S3. This is probably wrong. Please filter files "
"before uploading."
)
exit(1)
user, _ = self._api.whoami(token)
namespace = self.args.organization if self.args.organization is not None else user
for filepath, filename in files:
print(
f"About to upload file {ANSI.bold(filepath)} to S3 under filename {ANSI.bold(filename)} and namespace "
f"{ANSI.bold(namespace)}"
)
if not self.args.yes:
choice = input("Proceed? [Y/n] ").lower()
if not (choice == "" or choice == "y" or choice == "yes"):
print("Abort")
exit()
print(ANSI.bold("Uploading... This might take a while if files are large"))
for filepath, filename in files:
try:
access_url = self._api.presign_and_upload(
token=token, filename=filename, filepath=filepath, organization=self.args.organization
)
except HTTPError as e:
print(e)
print(ANSI.red(e.response.text))
exit(1)
print("Your file now lives at:")
print(access_url)
...@@ -18,7 +18,7 @@ deps = { ...@@ -18,7 +18,7 @@ deps = {
"flax": "flax>=0.3.4", "flax": "flax>=0.3.4",
"fugashi": "fugashi>=1.0", "fugashi": "fugashi>=1.0",
"GitPython": "GitPython<3.1.19", "GitPython": "GitPython<3.1.19",
"huggingface-hub": "huggingface-hub>=0.0.17", "huggingface-hub": "huggingface-hub>=0.1.0,<1.0",
"importlib_metadata": "importlib_metadata", "importlib_metadata": "importlib_metadata",
"ipadic": "ipadic>=1.0.0,<2.0", "ipadic": "ipadic>=1.0.0,<2.0",
"isort": "isort>=5.5.4", "isort": "isort>=5.5.4",
......
...@@ -48,7 +48,7 @@ from tqdm.auto import tqdm ...@@ -48,7 +48,7 @@ from tqdm.auto import tqdm
import requests import requests
from filelock import FileLock from filelock import FileLock
from huggingface_hub import HfApi, HfFolder, Repository from huggingface_hub import HfFolder, Repository, create_repo, list_repo_files, whoami
from transformers.utils.versions import importlib_metadata from transformers.utils.versions import importlib_metadata
from . import __version__ from . import __version__
...@@ -1808,17 +1808,14 @@ def get_list_of_files( ...@@ -1808,17 +1808,14 @@ def get_list_of_files(
if is_offline_mode() or local_files_only: if is_offline_mode() or local_files_only:
return [] return []
# Otherwise we grab the token and use the model_info method. # Otherwise we grab the token and use the list_repo_files method.
if isinstance(use_auth_token, str): if isinstance(use_auth_token, str):
token = use_auth_token token = use_auth_token
elif use_auth_token is True: elif use_auth_token is True:
token = HfFolder.get_token() token = HfFolder.get_token()
else: else:
token = None token = None
model_info = HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).model_info( return list_repo_files(path_or_repo, revision=revision, token=token)
path_or_repo, revision=revision, token=token
)
return [f.rfilename for f in model_info.siblings]
class cached_property(property): class cached_property(property):
...@@ -2308,7 +2305,7 @@ class PushToHubMixin: ...@@ -2308,7 +2305,7 @@ class PushToHubMixin:
token = None token = None
# Special provision for the test endpoint (CI) # Special provision for the test endpoint (CI)
return HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).create_repo( return create_repo(
token, token,
repo_name, repo_name,
organization=organization, organization=organization,
...@@ -2366,7 +2363,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: ...@@ -2366,7 +2363,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
if token is None: if token is None:
token = HfFolder.get_token() token = HfFolder.get_token()
if organization is None: if organization is None:
username = HfApi().whoami(token)["name"] username = whoami(token)["name"]
return f"{username}/{model_id}" return f"{username}/{model_id}"
else: else:
return f"{organization}/{model_id}" return f"{organization}/{model_id}"
......
...@@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional, Union ...@@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional, Union
import requests import requests
import yaml import yaml
from huggingface_hub import HfApi from huggingface_hub import model_info
from . import __version__ from . import __version__
from .file_utils import ( from .file_utils import (
...@@ -387,8 +387,8 @@ class TrainingSummary: ...@@ -387,8 +387,8 @@ class TrainingSummary:
and len(self.finetuned_from) > 0 and len(self.finetuned_from) > 0
): ):
try: try:
model_info = HfApi().model_info(self.finetuned_from) info = model_info(self.finetuned_from)
for tag in model_info.tags: for tag in info.tags:
if tag.startswith("license:"): if tag.startswith("license:"):
self.license = tag[8:] self.license = tag[8:]
except requests.exceptions.HTTPError: except requests.exceptions.HTTPError:
......
...@@ -27,7 +27,7 @@ import torch ...@@ -27,7 +27,7 @@ import torch
from torch import nn from torch import nn
from tqdm import tqdm from tqdm import tqdm
from huggingface_hub.hf_api import HfApi from huggingface_hub.hf_api import list_models
from transformers import MarianConfig, MarianMTModel, MarianTokenizer from transformers import MarianConfig, MarianMTModel, MarianTokenizer
...@@ -64,8 +64,7 @@ def load_layers_(layer_lst: nn.ModuleList, opus_state: dict, converter, is_decod ...@@ -64,8 +64,7 @@ def load_layers_(layer_lst: nn.ModuleList, opus_state: dict, converter, is_decod
def find_pretrained_model(src_lang: str, tgt_lang: str) -> List[str]: def find_pretrained_model(src_lang: str, tgt_lang: str) -> List[str]:
"""Find models that can accept src_lang as input and return tgt_lang as output.""" """Find models that can accept src_lang as input and return tgt_lang as output."""
prefix = "Helsinki-NLP/opus-mt-" prefix = "Helsinki-NLP/opus-mt-"
api = HfApi() model_list = list_models()
model_list = api.list_models()
model_ids = [x.modelId for x in model_list if x.modelId.startswith("Helsinki-NLP")] model_ids = [x.modelId for x in model_list if x.modelId.startswith("Helsinki-NLP")]
src_and_targ = [ src_and_targ = [
remove_prefix(m, prefix).lower().split("-") for m in model_ids if "+" not in m remove_prefix(m, prefix).lower().split("-") for m in model_ids if "+" not in m
......
...@@ -19,11 +19,11 @@ import os ...@@ -19,11 +19,11 @@ import os
import tempfile import tempfile
import unittest import unittest
from huggingface_hub import HfApi from huggingface_hub import delete_repo, login
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers import BertConfig, GPT2Config, is_torch_available from transformers import BertConfig, GPT2Config, is_torch_available
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test from transformers.testing_utils import PASS, USER, is_staging_test
config_common_kwargs = { config_common_kwargs = {
...@@ -194,18 +194,17 @@ class ConfigTester(object): ...@@ -194,18 +194,17 @@ class ConfigTester(object):
class ConfigPushToHubTester(unittest.TestCase): class ConfigPushToHubTester(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls._api = HfApi(endpoint=ENDPOINT_STAGING) cls._token = login(username=USER, password=PASS)
cls._token = cls._api.login(username=USER, password=PASS)
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
try: try:
cls._api.delete_repo(token=cls._token, name="test-config") delete_repo(token=cls._token, name="test-config")
except HTTPError: except HTTPError:
pass pass
try: try:
cls._api.delete_repo(token=cls._token, name="test-config-org", organization="valid_org") delete_repo(token=cls._token, name="test-config-org", organization="valid_org")
except HTTPError: except HTTPError:
pass pass
......
...@@ -28,13 +28,12 @@ from typing import Dict, List, Tuple ...@@ -28,13 +28,12 @@ from typing import Dict, List, Tuple
import numpy as np import numpy as np
import transformers import transformers
from huggingface_hub import HfApi, Repository from huggingface_hub import Repository, delete_repo, login
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers import AutoModel, AutoModelForSequenceClassification, is_torch_available, logging from transformers import AutoModel, AutoModelForSequenceClassification, is_torch_available, logging
from transformers.file_utils import WEIGHTS_NAME, is_flax_available, is_torch_fx_available from transformers.file_utils import WEIGHTS_NAME, is_flax_available, is_torch_fx_available
from transformers.models.auto import get_values from transformers.models.auto import get_values
from transformers.testing_utils import ( from transformers.testing_utils import (
ENDPOINT_STAGING,
PASS, PASS,
USER, USER,
CaptureLogger, CaptureLogger,
...@@ -2122,23 +2121,22 @@ class FakeModel(PreTrainedModel): ...@@ -2122,23 +2121,22 @@ class FakeModel(PreTrainedModel):
class ModelPushToHubTester(unittest.TestCase): class ModelPushToHubTester(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls._api = HfApi(endpoint=ENDPOINT_STAGING) cls._token = login(username=USER, password=PASS)
cls._token = cls._api.login(username=USER, password=PASS)
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
try: try:
cls._api.delete_repo(token=cls._token, name="test-model") delete_repo(token=cls._token, name="test-model")
except HTTPError: except HTTPError:
pass pass
try: try:
cls._api.delete_repo(token=cls._token, name="test-model-org", organization="valid_org") delete_repo(token=cls._token, name="test-model-org", organization="valid_org")
except HTTPError: except HTTPError:
pass pass
try: try:
cls._api.delete_repo(token=cls._token, name="test-dynamic-model") delete_repo(token=cls._token, name="test-dynamic-model")
except HTTPError: except HTTPError:
pass pass
......
...@@ -22,19 +22,11 @@ from typing import List, Tuple ...@@ -22,19 +22,11 @@ from typing import List, Tuple
import numpy as np import numpy as np
import transformers import transformers
from huggingface_hub import HfApi from huggingface_hub import delete_repo, login
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers import BertConfig, is_flax_available, is_torch_available from transformers import BertConfig, is_flax_available, is_torch_available
from transformers.models.auto import get_values from transformers.models.auto import get_values
from transformers.testing_utils import ( from transformers.testing_utils import PASS, USER, CaptureLogger, is_pt_flax_cross_test, is_staging_test, require_flax
ENDPOINT_STAGING,
PASS,
USER,
CaptureLogger,
is_pt_flax_cross_test,
is_staging_test,
require_flax,
)
from transformers.utils import logging from transformers.utils import logging
...@@ -627,18 +619,17 @@ class FlaxModelTesterMixin: ...@@ -627,18 +619,17 @@ class FlaxModelTesterMixin:
class FlaxModelPushToHubTester(unittest.TestCase): class FlaxModelPushToHubTester(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls._api = HfApi(endpoint=ENDPOINT_STAGING) cls._token = login(username=USER, password=PASS)
cls._token = cls._api.login(username=USER, password=PASS)
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
try: try:
cls._api.delete_repo(token=cls._token, name="test-model-flax") delete_repo(token=cls._token, name="test-model-flax")
except HTTPError: except HTTPError:
pass pass
try: try:
cls._api.delete_repo(token=cls._token, name="test-model-flax-org", organization="valid_org") delete_repo(token=cls._token, name="test-model-flax-org", organization="valid_org")
except HTTPError: except HTTPError:
pass pass
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import tempfile import tempfile
import unittest import unittest
from huggingface_hub.hf_api import HfApi from huggingface_hub.hf_api import list_models
from transformers import MarianConfig, is_torch_available from transformers import MarianConfig, is_torch_available
from transformers.file_utils import cached_property from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
...@@ -296,7 +296,7 @@ class ModelManagementTests(unittest.TestCase): ...@@ -296,7 +296,7 @@ class ModelManagementTests(unittest.TestCase):
@slow @slow
@require_torch @require_torch
def test_model_names(self): def test_model_names(self):
model_list = HfApi().list_models() model_list = list_models()
model_ids = [x.modelId for x in model_list if x.modelId.startswith(ORG_NAME)] model_ids = [x.modelId for x in model_list if x.modelId.startswith(ORG_NAME)]
bad_model_ids = [mid for mid in model_ids if "+" in model_ids] bad_model_ids = [mid for mid in model_ids if "+" in model_ids]
self.assertListEqual([], bad_model_ids) self.assertListEqual([], bad_model_ids)
......
...@@ -24,12 +24,11 @@ import unittest ...@@ -24,12 +24,11 @@ import unittest
from importlib import import_module from importlib import import_module
from typing import List, Tuple from typing import List, Tuple
from huggingface_hub import HfApi from huggingface_hub import delete_repo, login
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers import is_tf_available from transformers import is_tf_available
from transformers.models.auto import get_values from transformers.models.auto import get_values
from transformers.testing_utils import ( from transformers.testing_utils import (
ENDPOINT_STAGING,
PASS, PASS,
USER, USER,
CaptureLogger, CaptureLogger,
...@@ -1530,18 +1529,17 @@ class UtilsFunctionsTest(unittest.TestCase): ...@@ -1530,18 +1529,17 @@ class UtilsFunctionsTest(unittest.TestCase):
class TFModelPushToHubTester(unittest.TestCase): class TFModelPushToHubTester(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls._api = HfApi(endpoint=ENDPOINT_STAGING) cls._token = login(username=USER, password=PASS)
cls._token = cls._api.login(username=USER, password=PASS)
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
try: try:
cls._api.delete_repo(token=cls._token, name="test-model-tf") delete_repo(token=cls._token, name="test-model-tf")
except HTTPError: except HTTPError:
pass pass
try: try:
cls._api.delete_repo(token=cls._token, name="test-model-tf-org", organization="valid_org") delete_repo(token=cls._token, name="test-model-tf-org", organization="valid_org")
except HTTPError: except HTTPError:
pass pass
......
...@@ -27,7 +27,7 @@ from collections import OrderedDict ...@@ -27,7 +27,7 @@ from collections import OrderedDict
from itertools import takewhile from itertools import takewhile
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
from huggingface_hub import HfApi from huggingface_hub import delete_repo, login
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers import ( from transformers import (
AlbertTokenizer, AlbertTokenizer,
...@@ -44,7 +44,6 @@ from transformers import ( ...@@ -44,7 +44,6 @@ from transformers import (
is_torch_available, is_torch_available,
) )
from transformers.testing_utils import ( from transformers.testing_utils import (
ENDPOINT_STAGING,
PASS, PASS,
USER, USER,
get_tests_dir, get_tests_dir,
...@@ -3520,18 +3519,17 @@ class TokenizerPushToHubTester(unittest.TestCase): ...@@ -3520,18 +3519,17 @@ class TokenizerPushToHubTester(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls._api = HfApi(endpoint=ENDPOINT_STAGING) cls._token = login(username=USER, password=PASS)
cls._token = cls._api.login(username=USER, password=PASS)
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
try: try:
cls._api.delete_repo(token=cls._token, name="test-tokenizer") delete_repo(token=cls._token, name="test-tokenizer")
except HTTPError: except HTTPError:
pass pass
try: try:
cls._api.delete_repo(token=cls._token, name="test-tokenizer-org", organization="valid_org") delete_repo(token=cls._token, name="test-tokenizer-org", organization="valid_org")
except HTTPError: except HTTPError:
pass pass
......
...@@ -26,7 +26,7 @@ from pathlib import Path ...@@ -26,7 +26,7 @@ from pathlib import Path
import numpy as np import numpy as np
from huggingface_hub import HfApi, Repository from huggingface_hub import Repository, delete_repo, login
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers import ( from transformers import (
AutoTokenizer, AutoTokenizer,
...@@ -1307,19 +1307,18 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): ...@@ -1307,19 +1307,18 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
class TrainerIntegrationWithHubTester(unittest.TestCase): class TrainerIntegrationWithHubTester(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls._api = HfApi(endpoint=ENDPOINT_STAGING) cls._token = login(username=USER, password=PASS)
cls._token = cls._api.login(username=USER, password=PASS)
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
for model in ["test-trainer", "test-trainer-epoch", "test-trainer-step"]: for model in ["test-trainer", "test-trainer-epoch", "test-trainer-step"]:
try: try:
cls._api.delete_repo(token=cls._token, name=model) delete_repo(token=cls._token, name=model)
except HTTPError: except HTTPError:
pass pass
try: try:
cls._api.delete_repo(token=cls._token, name="test-trainer-org", organization="valid_org") delete_repo(token=cls._token, name="test-trainer-org", organization="valid_org")
except HTTPError: except HTTPError:
pass pass
...@@ -1396,6 +1395,10 @@ class TrainerIntegrationWithHubTester(unittest.TestCase): ...@@ -1396,6 +1395,10 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
print(commits, len(commits)) print(commits, len(commits))
def test_push_to_hub_with_saves_each_n_steps(self): def test_push_to_hub_with_saves_each_n_steps(self):
num_gpus = max(1, get_gpu_count())
if num_gpus > 2:
return
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer( trainer = get_regression_trainer(
output_dir=os.path.join(tmp_dir, "test-trainer-step"), output_dir=os.path.join(tmp_dir, "test-trainer-step"),
...@@ -1409,7 +1412,8 @@ class TrainerIntegrationWithHubTester(unittest.TestCase): ...@@ -1409,7 +1412,8 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
_ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-step", use_auth_token=self._token) _ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-step", use_auth_token=self._token)
commits = self.get_commit_history(tmp_dir) commits = self.get_commit_history(tmp_dir)
expected_commits = [f"Training in progress, step {i}" for i in range(20, 0, -5)] total_steps = 20 // num_gpus
expected_commits = [f"Training in progress, step {i}" for i in range(total_steps, 0, -5)]
expected_commits.append("initial commit") expected_commits.append("initial commit")
self.assertListEqual(commits, expected_commits) self.assertListEqual(commits, expected_commits)
print(commits, len(commits)) print(commits, len(commits))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment