Unverified Commit 9a498c37 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Rely on huggingface_hub for common tools (#13100)

* Remove hf_api module and use hugginface_hub

* Style

* Fix to test_fetcher

* Quality
parent 6900dded
......@@ -103,8 +103,8 @@ Here is the code to see all available pretrained models on the hub:
.. code-block:: python
from transformers.hf_api import HfApi
model_list = HfApi().model_list()
from huggingface_hub.hf_api import HfApi
model_list = HfApi().list_models()
org = "Helsinki-NLP"
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
suffix = [x.split('/')[1] for x in model_ids]
......
......@@ -14,10 +14,10 @@ import lightning_base
from convert_pl_checkpoint_to_hf import convert_pl_to_hf
from distillation import distill_main
from finetune import SummarizationModule, main
from huggingface_hub.hf_api import HfApi
from parameterized import parameterized
from run_eval import generate_summaries_or_translations
from transformers import AutoConfig, AutoModelForSeq2SeqLM
from transformers.hf_api import HfApi
from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_gpu, slow
from utils import label_smoothed_nll_loss, lmap, load_json
......@@ -130,7 +130,7 @@ class TestSummarizationDistiller(TestCasePlus):
def test_hub_configs(self):
"""I put require_torch_gpu cause I only want this to run with self-scheduled."""
model_list = HfApi().model_list()
model_list = HfApi().list_models()
org = "sshleifer"
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
allowed_to_be_broken = ["sshleifer/blenderbot-3B", "sshleifer/blenderbot-90M"]
......
......@@ -15,14 +15,13 @@
import os
import subprocess
import sys
import warnings
from argparse import ArgumentParser
from getpass import getpass
from typing import List, Union
from huggingface_hub.hf_api import HfApi, HfFolder
from requests.exceptions import HTTPError
from ..hf_api import HfApi, HfFolder
from . import BaseTransformersCLICommand
......@@ -148,6 +147,12 @@ class BaseUserCommand:
class LoginCommand(BaseUserCommand):
def run(self):
print(
ANSI.red(
"WARNING! `transformers-cli login` is deprecated and will be removed in v5. Please use "
"`huggingface-cli login` instead."
)
)
print( # docstyle-ignore
"""
_| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
......@@ -175,6 +180,12 @@ class LoginCommand(BaseUserCommand):
class WhoamiCommand(BaseUserCommand):
def run(self):
print(
ANSI.red(
"WARNING! `transformers-cli whoami` is deprecated and will be removed in v5. Please use "
"`huggingface-cli whoami` instead."
)
)
token = HfFolder.get_token()
if token is None:
print("Not logged in")
......@@ -192,6 +203,12 @@ class WhoamiCommand(BaseUserCommand):
class LogoutCommand(BaseUserCommand):
def run(self):
print(
ANSI.red(
"WARNING! `transformers-cli logout` is deprecated and will be removed in v5. Please use "
"`huggingface-cli logout` instead."
)
)
token = HfFolder.get_token()
if token is None:
print("Not logged in")
......@@ -203,8 +220,11 @@ class LogoutCommand(BaseUserCommand):
class ListObjsCommand(BaseUserCommand):
def run(self):
warnings.warn(
"Managing repositories through transformers-cli is deprecated. Please use `huggingface-cli` instead."
print(
ANSI.red(
"WARNING! Managing repositories through transformers-cli is deprecated. "
"Please use `huggingface-cli` instead."
)
)
token = HfFolder.get_token()
if token is None:
......@@ -225,8 +245,11 @@ class ListObjsCommand(BaseUserCommand):
class DeleteObjCommand(BaseUserCommand):
def run(self):
warnings.warn(
"Managing repositories through transformers-cli is deprecated. Please use `huggingface-cli` instead."
print(
ANSI.red(
"WARNING! Managing repositories through transformers-cli is deprecated. "
"Please use `huggingface-cli` instead."
)
)
token = HfFolder.get_token()
if token is None:
......@@ -243,8 +266,11 @@ class DeleteObjCommand(BaseUserCommand):
class ListReposObjsCommand(BaseUserCommand):
def run(self):
warnings.warn(
"Managing repositories through transformers-cli is deprecated. Please use `huggingface-cli` instead."
print(
ANSI.red(
"WARNING! Managing repositories through transformers-cli is deprecated. "
"Please use `huggingface-cli` instead."
)
)
token = HfFolder.get_token()
if token is None:
......@@ -265,8 +291,11 @@ class ListReposObjsCommand(BaseUserCommand):
class RepoCreateCommand(BaseUserCommand):
def run(self):
warnings.warn(
"Managing repositories through transformers-cli is deprecated. Please use `huggingface-cli` instead."
print(
ANSI.red(
"WARNING! Managing repositories through transformers-cli is deprecated. "
"Please use `huggingface-cli` instead."
)
)
token = HfFolder.get_token()
if token is None:
......@@ -339,8 +368,11 @@ class UploadCommand(BaseUserCommand):
return files
def run(self):
warnings.warn(
"Managing repositories through transformers-cli is deprecated. Please use `huggingface-cli` instead."
print(
ANSI.red(
"WARNING! Managing repositories through transformers-cli is deprecated. "
"Please use `huggingface-cli` instead."
)
)
token = HfFolder.get_token()
if token is None:
......
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
from os.path import expanduser
from typing import Dict, List, Optional, Tuple
from tqdm import tqdm
import requests
ENDPOINT = "https://huggingface.co"
class RepoObj:
"""
HuggingFace git-based system, data structure that represents a file belonging to the current user.
"""
def __init__(self, filename: str, lastModified: str, commit: str, size: int, **kwargs):
self.filename = filename
self.lastModified = lastModified
self.commit = commit
self.size = size
class ModelSibling:
"""
Data structure that represents a public file inside a model, accessible from huggingface.co
"""
def __init__(self, rfilename: str, **kwargs):
self.rfilename = rfilename # filename relative to the model root
for k, v in kwargs.items():
setattr(self, k, v)
class ModelInfo:
"""
Info about a public model accessible from huggingface.co
"""
def __init__(
self,
modelId: Optional[str] = None, # id of model
tags: List[str] = [],
pipeline_tag: Optional[str] = None,
siblings: Optional[List[Dict]] = None, # list of files that constitute the model
**kwargs
):
self.modelId = modelId
self.tags = tags
self.pipeline_tag = pipeline_tag
self.siblings = [ModelSibling(**x) for x in siblings] if siblings is not None else None
for k, v in kwargs.items():
setattr(self, k, v)
class HfApi:
def __init__(self, endpoint=None):
self.endpoint = endpoint if endpoint is not None else ENDPOINT
def login(self, username: str, password: str) -> str:
"""
Call HF API to sign in a user and get a token if credentials are valid.
Outputs: token if credentials are valid
Throws: requests.exceptions.HTTPError if credentials are invalid
"""
path = f"{self.endpoint}/api/login"
r = requests.post(path, json={"username": username, "password": password})
r.raise_for_status()
d = r.json()
return d["token"]
def whoami(self, token: str) -> Tuple[str, List[str]]:
"""
Call HF API to know "whoami"
"""
path = f"{self.endpoint}/api/whoami"
r = requests.get(path, headers={"authorization": f"Bearer {token}"})
r.raise_for_status()
d = r.json()
return d["user"], d["orgs"]
def logout(self, token: str) -> None:
"""
Call HF API to log out.
"""
path = f"{self.endpoint}/api/logout"
r = requests.post(path, headers={"authorization": f"Bearer {token}"})
r.raise_for_status()
def model_list(self) -> List[ModelInfo]:
"""
Get the public list of all the models on huggingface.co
"""
path = f"{self.endpoint}/api/models"
r = requests.get(path)
r.raise_for_status()
d = r.json()
return [ModelInfo(**x) for x in d]
def list_repos_objs(self, token: str, organization: Optional[str] = None) -> List[RepoObj]:
"""
HuggingFace git-based system, used for models.
Call HF API to list all stored files for user (or one of their organizations).
"""
path = f"{self.endpoint}/api/repos/ls"
params = {"organization": organization} if organization is not None else None
r = requests.get(path, params=params, headers={"authorization": f"Bearer {token}"})
r.raise_for_status()
d = r.json()
return [RepoObj(**x) for x in d]
def create_repo(
self,
token: str,
name: str,
organization: Optional[str] = None,
private: Optional[bool] = None,
exist_ok=False,
lfsmultipartthresh: Optional[int] = None,
) -> str:
"""
HuggingFace git-based system, used for models.
Call HF API to create a whole repo.
Params:
private: Whether the model repo should be private (requires a paid huggingface.co account)
exist_ok: Do not raise an error if repo already exists
lfsmultipartthresh: Optional: internal param for testing purposes.
"""
path = f"{self.endpoint}/api/repos/create"
json = {"name": name, "organization": organization, "private": private}
if lfsmultipartthresh is not None:
json["lfsmultipartthresh"] = lfsmultipartthresh
r = requests.post(
path,
headers={"authorization": f"Bearer {token}"},
json=json,
)
if exist_ok and r.status_code == 409:
return ""
r.raise_for_status()
d = r.json()
return d["url"]
def delete_repo(self, token: str, name: str, organization: Optional[str] = None):
"""
HuggingFace git-based system, used for models.
Call HF API to delete a whole repo.
CAUTION(this is irreversible).
"""
path = f"{self.endpoint}/api/repos/delete"
r = requests.delete(
path,
headers={"authorization": f"Bearer {token}"},
json={"name": name, "organization": organization},
)
r.raise_for_status()
class TqdmProgressFileReader:
"""
Wrap an io.BufferedReader `f` (such as the output of `open(…, "rb")`) and override `f.read()` so as to display a
tqdm progress bar.
see github.com/huggingface/transformers/pull/2078#discussion_r354739608 for implementation details.
"""
def __init__(self, f: io.BufferedReader):
self.f = f
self.total_size = os.fstat(f.fileno()).st_size
self.pbar = tqdm(total=self.total_size, leave=False)
self.read = f.read
f.read = self._read
def _read(self, n=-1):
self.pbar.update(n)
return self.read(n)
def close(self):
self.pbar.close()
class HfFolder:
path_token = expanduser("~/.huggingface/token")
@classmethod
def save_token(cls, token):
"""
Save token, creating folder as needed.
"""
os.makedirs(os.path.dirname(cls.path_token), exist_ok=True)
with open(cls.path_token, "w+") as f:
f.write(token)
@classmethod
def get_token(cls):
"""
Get token or None if not existent.
"""
try:
with open(cls.path_token, "r") as f:
return f.read()
except FileNotFoundError:
pass
@classmethod
def delete_token(cls):
"""
Delete token. Do not fail if token does not exist.
"""
try:
os.remove(cls.path_token)
except FileNotFoundError:
pass
......@@ -27,8 +27,8 @@ import torch
from torch import nn
from tqdm import tqdm
from huggingface_hub.hf_api import HfApi
from transformers import MarianConfig, MarianMTModel, MarianTokenizer
from transformers.hf_api import HfApi
def remove_suffix(text: str, suffix: str):
......@@ -65,7 +65,7 @@ def find_pretrained_model(src_lang: str, tgt_lang: str) -> List[str]:
"""Find models that can accept src_lang as input and return tgt_lang as output."""
prefix = "Helsinki-NLP/opus-mt-"
api = HfApi()
model_list = api.model_list()
model_list = api.list_models()
model_ids = [x.modelId for x in model_list if x.modelId.startswith("Helsinki-NLP")]
src_and_targ = [
remove_prefix(m, prefix).lower().split("-") for m in model_ids if "+" not in m
......
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import subprocess
import time
import unittest
from requests.exceptions import HTTPError
from transformers.hf_api import HfApi, HfFolder, ModelInfo, RepoObj
from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test, require_git_lfs
ENDPOINT_STAGING_BASIC_AUTH = f"https://{USER}:{PASS}@moon-staging.huggingface.co"
REPO_NAME = f"my-model-{int(time.time())}"
REPO_NAME_LARGE_FILE = f"my-model-largefiles-{int(time.time())}"
WORKING_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/working_repo")
LARGE_FILE_14MB = "https://cdn-media.huggingface.co/lfs-largefiles/progit.epub"
LARGE_FILE_18MB = "https://cdn-media.huggingface.co/lfs-largefiles/progit.pdf"
class HfApiCommonTest(unittest.TestCase):
_api = HfApi(endpoint=ENDPOINT_STAGING)
class HfApiLoginTest(HfApiCommonTest):
def test_login_invalid(self):
with self.assertRaises(HTTPError):
self._api.login(username=USER, password="fake")
def test_login_valid(self):
token = self._api.login(username=USER, password=PASS)
self.assertIsInstance(token, str)
class HfApiEndpointsTest(HfApiCommonTest):
@classmethod
def setUpClass(cls):
"""
Share this valid token in all tests below.
"""
cls._token = cls._api.login(username=USER, password=PASS)
def test_whoami(self):
user, orgs = self._api.whoami(token=self._token)
self.assertEqual(user, USER)
self.assertIsInstance(orgs, list)
def test_list_repos_objs(self):
objs = self._api.list_repos_objs(token=self._token)
self.assertIsInstance(objs, list)
if len(objs) > 0:
o = objs[-1]
self.assertIsInstance(o, RepoObj)
def test_create_and_delete_repo(self):
self._api.create_repo(token=self._token, name=REPO_NAME)
self._api.delete_repo(token=self._token, name=REPO_NAME)
class HfApiPublicTest(unittest.TestCase):
def test_staging_model_list(self):
_api = HfApi(endpoint=ENDPOINT_STAGING)
_ = _api.model_list()
def test_model_list(self):
_api = HfApi()
models = _api.model_list()
self.assertGreater(len(models), 100)
self.assertIsInstance(models[0], ModelInfo)
class HfFolderTest(unittest.TestCase):
def test_token_workflow(self):
"""
Test the whole token save/get/delete workflow,
with the desired behavior with respect to non-existent tokens.
"""
token = f"token-{int(time.time())}"
HfFolder.save_token(token)
self.assertEqual(HfFolder.get_token(), token)
HfFolder.delete_token()
HfFolder.delete_token()
# ^^ not an error, we test that the
# second call does not fail.
self.assertEqual(HfFolder.get_token(), None)
@require_git_lfs
@is_staging_test
class HfLargefilesTest(HfApiCommonTest):
@classmethod
def setUpClass(cls):
"""
Share this valid token in all tests below.
"""
cls._token = cls._api.login(username=USER, password=PASS)
def setUp(self):
try:
shutil.rmtree(WORKING_REPO_DIR)
except FileNotFoundError:
pass
def tearDown(self):
self._api.delete_repo(token=self._token, name=REPO_NAME_LARGE_FILE)
def setup_local_clone(self, REMOTE_URL):
REMOTE_URL_AUTH = REMOTE_URL.replace(ENDPOINT_STAGING, ENDPOINT_STAGING_BASIC_AUTH)
subprocess.run(["git", "clone", REMOTE_URL_AUTH, WORKING_REPO_DIR], check=True, capture_output=True)
subprocess.run(["git", "lfs", "track", "*.pdf"], check=True, cwd=WORKING_REPO_DIR)
subprocess.run(["git", "lfs", "track", "*.epub"], check=True, cwd=WORKING_REPO_DIR)
def test_end_to_end_thresh_6M(self):
REMOTE_URL = self._api.create_repo(
token=self._token, name=REPO_NAME_LARGE_FILE, lfsmultipartthresh=6 * 10 ** 6
)
self.setup_local_clone(REMOTE_URL)
subprocess.run(["wget", LARGE_FILE_18MB], check=True, capture_output=True, cwd=WORKING_REPO_DIR)
subprocess.run(["git", "add", "*"], check=True, cwd=WORKING_REPO_DIR)
subprocess.run(["git", "commit", "-m", "commit message"], check=True, cwd=WORKING_REPO_DIR)
# This will fail as we haven't set up our custom transfer agent yet.
failed_process = subprocess.run(["git", "push"], capture_output=True, cwd=WORKING_REPO_DIR)
self.assertEqual(failed_process.returncode, 1)
self.assertIn("transformers-cli lfs-enable-largefiles", failed_process.stderr.decode())
# ^ Instructions on how to fix this are included in the error message.
subprocess.run(["transformers-cli", "lfs-enable-largefiles", WORKING_REPO_DIR], check=True)
start_time = time.time()
subprocess.run(["git", "push"], check=True, cwd=WORKING_REPO_DIR)
print("took", time.time() - start_time)
# To be 100% sure, let's download the resolved file
pdf_url = f"{REMOTE_URL}/resolve/main/progit.pdf"
DEST_FILENAME = "uploaded.pdf"
subprocess.run(["wget", pdf_url, "-O", DEST_FILENAME], check=True, capture_output=True, cwd=WORKING_REPO_DIR)
dest_filesize = os.stat(os.path.join(WORKING_REPO_DIR, DEST_FILENAME)).st_size
self.assertEqual(dest_filesize, 18685041)
def test_end_to_end_thresh_16M(self):
# Here we'll push one multipart and one non-multipart file in the same commit, and see what happens
REMOTE_URL = self._api.create_repo(
token=self._token, name=REPO_NAME_LARGE_FILE, lfsmultipartthresh=16 * 10 ** 6
)
self.setup_local_clone(REMOTE_URL)
subprocess.run(["wget", LARGE_FILE_18MB], check=True, capture_output=True, cwd=WORKING_REPO_DIR)
subprocess.run(["wget", LARGE_FILE_14MB], check=True, capture_output=True, cwd=WORKING_REPO_DIR)
subprocess.run(["git", "add", "*"], check=True, cwd=WORKING_REPO_DIR)
subprocess.run(["git", "commit", "-m", "both files in same commit"], check=True, cwd=WORKING_REPO_DIR)
subprocess.run(["transformers-cli", "lfs-enable-largefiles", WORKING_REPO_DIR], check=True)
start_time = time.time()
subprocess.run(["git", "push"], check=True, cwd=WORKING_REPO_DIR)
print("took", time.time() - start_time)
......@@ -17,9 +17,9 @@
import tempfile
import unittest
from huggingface_hub.hf_api import HfApi
from transformers import MarianConfig, is_torch_available
from transformers.file_utils import cached_property
from transformers.hf_api import HfApi
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
......@@ -296,7 +296,7 @@ class ModelManagementTests(unittest.TestCase):
@slow
@require_torch
def test_model_names(self):
model_list = HfApi().model_list()
model_list = HfApi().list_models()
model_ids = [x.modelId for x in model_list if x.modelId.startswith(ORG_NAME)]
bad_model_ids = [mid for mid in model_ids if "+" in model_ids]
self.assertListEqual([], bad_model_ids)
......
......@@ -412,6 +412,8 @@ def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None):
# Remove duplicates
test_files_to_run = sorted(list(set(test_files_to_run)))
# Make sure we did not end up with a test file that was removed
test_files_to_run = [f for f in test_files_to_run if os.path.isfile(f) or os.path.isdir(f)]
if filters is not None:
for filter in filters:
test_files_to_run = [f for f in test_files_to_run if f.startswith(filter)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment