Unverified commit 38c3cd52, authored by Sylvain Gugger, committed by GitHub
Browse files

Clean up utils.hub using the latest from hf_hub (#18857)

* Clean up utils.hub using the latest from hf_hub

* Adapt test

* Address review comment

* Fix test
parent 17981faf
...@@ -116,7 +116,7 @@ _deps = [ ...@@ -116,7 +116,7 @@ _deps = [
"fugashi>=1.0", "fugashi>=1.0",
"GitPython<3.1.19", "GitPython<3.1.19",
"hf-doc-builder>=0.3.0", "hf-doc-builder>=0.3.0",
"huggingface-hub>=0.8.1,<1.0", "huggingface-hub>=0.9.0,<1.0",
"importlib_metadata", "importlib_metadata",
"ipadic>=1.0.0,<2.0", "ipadic>=1.0.0,<2.0",
"isort>=5.5.4", "isort>=5.5.4",
......
...@@ -22,7 +22,7 @@ deps = { ...@@ -22,7 +22,7 @@ deps = {
"fugashi": "fugashi>=1.0", "fugashi": "fugashi>=1.0",
"GitPython": "GitPython<3.1.19", "GitPython": "GitPython<3.1.19",
"hf-doc-builder": "hf-doc-builder>=0.3.0", "hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.8.1,<1.0", "huggingface-hub": "huggingface-hub>=0.9.0,<1.0",
"importlib_metadata": "importlib_metadata", "importlib_metadata": "importlib_metadata",
"ipadic": "ipadic>=1.0.0,<2.0", "ipadic": "ipadic>=1.0.0,<2.0",
"isort": "isort>=5.5.4", "isort": "isort>=5.5.4",
......
...@@ -21,7 +21,6 @@ import shutil ...@@ -21,7 +21,6 @@ import shutil
import sys import sys
import traceback import traceback
import warnings import warnings
from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
from uuid import uuid4 from uuid import uuid4
...@@ -39,7 +38,12 @@ from huggingface_hub import ( ...@@ -39,7 +38,12 @@ from huggingface_hub import (
) )
from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
from huggingface_hub.file_download import REGEX_COMMIT_HASH from huggingface_hub.file_download import REGEX_COMMIT_HASH
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from huggingface_hub.utils import (
EntryNotFoundError,
LocalEntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
)
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers.utils.logging import tqdm from transformers.utils.logging import tqdm
...@@ -249,28 +253,6 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_h ...@@ -249,28 +253,6 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_h
return cached_file if os.path.isfile(cached_file) else None return cached_file if os.path.isfile(cached_file) else None
# If huggingface_hub changes the class of error for this to FileNotFoundError, we will be able to avoid that in the
# future.
LOCAL_FILES_ONLY_HF_ERROR = (
"Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable hf.co "
"look-ups and downloads online, set 'local_files_only' to False."
)
# In the future, this ugly contextmanager can be removed when huggingface_hub as a released version where we can
# activate/deactivate progress bars.
@contextmanager
def _patch_hf_hub_tqdm():
"""
A context manager to make huggingface hub use the tqdm version of Transformers (which is controlled by some utils)
in logging.
"""
old_tqdm = huggingface_hub.file_download.tqdm
huggingface_hub.file_download.tqdm = tqdm
yield
huggingface_hub.file_download.tqdm = old_tqdm
def cached_file( def cached_file(
path_or_repo_id: Union[str, os.PathLike], path_or_repo_id: Union[str, os.PathLike],
filename: str, filename: str,
...@@ -375,20 +357,19 @@ def cached_file( ...@@ -375,20 +357,19 @@ def cached_file(
user_agent = http_user_agent(user_agent) user_agent = http_user_agent(user_agent)
try: try:
# Load from URL or cache if already cached # Load from URL or cache if already cached
with _patch_hf_hub_tqdm(): resolved_file = hf_hub_download(
resolved_file = hf_hub_download( path_or_repo_id,
path_or_repo_id, filename,
filename, subfolder=None if len(subfolder) == 0 else subfolder,
subfolder=None if len(subfolder) == 0 else subfolder, revision=revision,
revision=revision, cache_dir=cache_dir,
cache_dir=cache_dir, user_agent=user_agent,
user_agent=user_agent, force_download=force_download,
force_download=force_download, proxies=proxies,
proxies=proxies, resume_download=resume_download,
resume_download=resume_download, use_auth_token=use_auth_token,
use_auth_token=use_auth_token, local_files_only=local_files_only,
local_files_only=local_files_only, )
)
except RepositoryNotFoundError: except RepositoryNotFoundError:
raise EnvironmentError( raise EnvironmentError(
...@@ -403,6 +384,19 @@ def cached_file( ...@@ -403,6 +384,19 @@ def cached_file(
"for this model name. Check the model page at " "for this model name. Check the model page at "
f"'https://huggingface.co/{path_or_repo_id}' for available revisions." f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
) )
except LocalEntryNotFoundError:
# We try to see if we have a cached version (not up to date):
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision)
if resolved_file is not None:
return resolved_file
if not _raise_exceptions_for_missing_entries or not _raise_exceptions_for_connection_errors:
return None
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at"
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EntryNotFoundError: except EntryNotFoundError:
if not _raise_exceptions_for_missing_entries: if not _raise_exceptions_for_missing_entries:
return None return None
...@@ -421,24 +415,6 @@ def cached_file( ...@@ -421,24 +415,6 @@ def cached_file(
return None return None
raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}") raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}")
except ValueError as err:
# HuggingFace Hub returns a ValueError for a missing file when local_files_only=True we need to catch it here
# This could be caught above along in `EntryNotFoundError` if hf_hub sent a different error message here
if LOCAL_FILES_ONLY_HF_ERROR in err.args[0] and local_files_only and not _raise_exceptions_for_missing_entries:
return None
# Otherwise we try to see if we have a cached version (not up to date):
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision)
if resolved_file is not None:
return resolved_file
if not _raise_exceptions_for_connection_errors:
return None
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at"
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
return resolved_file return resolved_file
......
...@@ -30,6 +30,8 @@ from typing import Optional ...@@ -30,6 +30,8 @@ from typing import Optional
from tqdm import auto as tqdm_lib from tqdm import auto as tqdm_lib
import huggingface_hub.utils as hf_hub_utils
_lock = threading.Lock() _lock = threading.Lock()
_default_handler: Optional[logging.Handler] = None _default_handler: Optional[logging.Handler] = None
...@@ -336,9 +338,11 @@ def enable_progress_bar(): ...@@ -336,9 +338,11 @@ def enable_progress_bar():
"""Enable tqdm progress bar.""" """Enable tqdm progress bar."""
global _tqdm_active global _tqdm_active
_tqdm_active = True _tqdm_active = True
hf_hub_utils.enable_progress_bars()
def disable_progress_bar(): def disable_progress_bar():
"""Disable tqdm progress bar.""" """Disable tqdm progress bar."""
global _tqdm_active global _tqdm_active
_tqdm_active = False _tqdm_active = False
hf_hub_utils.disable_progress_bars()
...@@ -14,10 +14,10 @@ ...@@ -14,10 +14,10 @@
import os import os
import unittest import unittest
from unittest.mock import patch
import transformers.models.bart.tokenization_bart import transformers.models.bart.tokenization_bart
from transformers import AutoConfig, logging from huggingface_hub.utils import are_progress_bars_disabled
from transformers import logging
from transformers.testing_utils import CaptureLogger, mockenv, mockenv_context from transformers.testing_utils import CaptureLogger, mockenv, mockenv_context
from transformers.utils.logging import disable_progress_bar, enable_progress_bar from transformers.utils.logging import disable_progress_bar, enable_progress_bar
...@@ -126,14 +126,8 @@ class HfArgumentParserTest(unittest.TestCase): ...@@ -126,14 +126,8 @@ class HfArgumentParserTest(unittest.TestCase):
def test_set_progress_bar_enabled(): def test_set_progress_bar_enabled():
TINY_MODEL = "hf-internal-testing/tiny-random-distilbert" disable_progress_bar()
with patch("tqdm.auto.tqdm") as mock_tqdm: assert are_progress_bars_disabled()
disable_progress_bar()
_ = AutoConfig.from_pretrained(TINY_MODEL, force_download=True)
mock_tqdm.assert_not_called()
mock_tqdm.reset_mock() enable_progress_bar()
assert not are_progress_bars_disabled()
enable_progress_bar()
_ = AutoConfig.from_pretrained(TINY_MODEL, force_download=True)
mock_tqdm.assert_called()
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment