Unverified Commit 38c3cd52 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Clean up utils.hub using the latest from hf_hub (#18857)

* Clean up utils.hub using the latest from hf_hub

* Adapt test

* Address review comment

* Fix test
parent 17981faf
......@@ -116,7 +116,7 @@ _deps = [
"fugashi>=1.0",
"GitPython<3.1.19",
"hf-doc-builder>=0.3.0",
"huggingface-hub>=0.8.1,<1.0",
"huggingface-hub>=0.9.0,<1.0",
"importlib_metadata",
"ipadic>=1.0.0,<2.0",
"isort>=5.5.4",
......
......@@ -22,7 +22,7 @@ deps = {
"fugashi": "fugashi>=1.0",
"GitPython": "GitPython<3.1.19",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.8.1,<1.0",
"huggingface-hub": "huggingface-hub>=0.9.0,<1.0",
"importlib_metadata": "importlib_metadata",
"ipadic": "ipadic>=1.0.0,<2.0",
"isort": "isort>=5.5.4",
......
......@@ -21,7 +21,6 @@ import shutil
import sys
import traceback
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from uuid import uuid4
......@@ -39,7 +38,12 @@ from huggingface_hub import (
)
from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
from huggingface_hub.file_download import REGEX_COMMIT_HASH
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from huggingface_hub.utils import (
EntryNotFoundError,
LocalEntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
)
from requests.exceptions import HTTPError
from transformers.utils.logging import tqdm
......@@ -249,28 +253,6 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_h
return cached_file if os.path.isfile(cached_file) else None
# If huggingface_hub changes the class of error for this to FileNotFoundError, we will be able to avoid that in the
# future.
LOCAL_FILES_ONLY_HF_ERROR = (
"Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable hf.co "
"look-ups and downloads online, set 'local_files_only' to False."
)
# In the future, this ugly contextmanager can be removed when huggingface_hub has a released version where we can
# activate/deactivate progress bars.
@contextmanager
def _patch_hf_hub_tqdm():
    """
    A context manager to make huggingface hub use the tqdm version of Transformers (which is controlled by some utils)
    in logging.

    The previous value of `huggingface_hub.file_download.tqdm` is saved on entry and put back on exit, including
    when the body of the `with` block raises.
    """
    old_tqdm = huggingface_hub.file_download.tqdm
    huggingface_hub.file_download.tqdm = tqdm
    try:
        yield
    finally:
        # An exception raised inside the `with` body is re-raised at the `yield`; without `finally` the patched
        # tqdm would stay installed after the context manager exits.
        huggingface_hub.file_download.tqdm = old_tqdm
def cached_file(
path_or_repo_id: Union[str, os.PathLike],
filename: str,
......@@ -375,20 +357,19 @@ def cached_file(
user_agent = http_user_agent(user_agent)
try:
# Load from URL or cache if already cached
with _patch_hf_hub_tqdm():
resolved_file = hf_hub_download(
path_or_repo_id,
filename,
subfolder=None if len(subfolder) == 0 else subfolder,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
use_auth_token=use_auth_token,
local_files_only=local_files_only,
)
resolved_file = hf_hub_download(
path_or_repo_id,
filename,
subfolder=None if len(subfolder) == 0 else subfolder,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
use_auth_token=use_auth_token,
local_files_only=local_files_only,
)
except RepositoryNotFoundError:
raise EnvironmentError(
......@@ -403,6 +384,19 @@ def cached_file(
"for this model name. Check the model page at "
f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
)
except LocalEntryNotFoundError:
# We try to see if we have a cached version (not up to date):
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision)
if resolved_file is not None:
return resolved_file
if not _raise_exceptions_for_missing_entries or not _raise_exceptions_for_connection_errors:
return None
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at"
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EntryNotFoundError:
if not _raise_exceptions_for_missing_entries:
return None
......@@ -421,24 +415,6 @@ def cached_file(
return None
raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}")
except ValueError as err:
# HuggingFace Hub returns a ValueError for a missing file when local_files_only=True we need to catch it here
# This could be caught above along in `EntryNotFoundError` if hf_hub sent a different error message here
if LOCAL_FILES_ONLY_HF_ERROR in err.args[0] and local_files_only and not _raise_exceptions_for_missing_entries:
return None
# Otherwise we try to see if we have a cached version (not up to date):
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision)
if resolved_file is not None:
return resolved_file
if not _raise_exceptions_for_connection_errors:
return None
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at"
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
return resolved_file
......
......@@ -30,6 +30,8 @@ from typing import Optional
from tqdm import auto as tqdm_lib
import huggingface_hub.utils as hf_hub_utils
_lock = threading.Lock()
_default_handler: Optional[logging.Handler] = None
......@@ -336,9 +338,11 @@ def enable_progress_bar():
"""Enable tqdm progress bar."""
global _tqdm_active
_tqdm_active = True
hf_hub_utils.enable_progress_bars()
def disable_progress_bar():
    """
    Turn tqdm progress bars off.

    Sets the module-level `_tqdm_active` flag to `False`, then calls `hf_hub_utils.disable_progress_bars()` so that
    `huggingface_hub` disables its progress bars as well.
    """
    global _tqdm_active
    _tqdm_active = False
    hf_hub_utils.disable_progress_bars()
......@@ -14,10 +14,10 @@
import os
import unittest
from unittest.mock import patch
import transformers.models.bart.tokenization_bart
from transformers import AutoConfig, logging
from huggingface_hub.utils import are_progress_bars_disabled
from transformers import logging
from transformers.testing_utils import CaptureLogger, mockenv, mockenv_context
from transformers.utils.logging import disable_progress_bar, enable_progress_bar
......@@ -126,14 +126,8 @@ class HfArgumentParserTest(unittest.TestCase):
def test_set_progress_bar_enabled():
TINY_MODEL = "hf-internal-testing/tiny-random-distilbert"
with patch("tqdm.auto.tqdm") as mock_tqdm:
disable_progress_bar()
_ = AutoConfig.from_pretrained(TINY_MODEL, force_download=True)
mock_tqdm.assert_not_called()
disable_progress_bar()
assert are_progress_bars_disabled()
mock_tqdm.reset_mock()
enable_progress_bar()
_ = AutoConfig.from_pretrained(TINY_MODEL, force_download=True)
mock_tqdm.assert_called()
enable_progress_bar()
assert not are_progress_bars_disabled()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment