Unverified Commit 06eda5b2 authored by Lucain's avatar Lucain Committed by GitHub
Browse files

Raise initial HTTPError if pipeline is not cached locally (#4230)

* Raise initial HTTPError if pipeline is not cached locally

* make style
parent 8e5921ca
...@@ -1248,6 +1248,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -1248,6 +1248,7 @@ class DiffusionPipeline(ConfigMixin):
allow_patterns = None allow_patterns = None
ignore_patterns = None ignore_patterns = None
model_info_call_error: Optional[Exception] = None
if not local_files_only: if not local_files_only:
try: try:
info = model_info( info = model_info(
...@@ -1258,6 +1259,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -1258,6 +1259,7 @@ class DiffusionPipeline(ConfigMixin):
except HTTPError as e: except HTTPError as e:
logger.warn(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.") logger.warn(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
local_files_only = True local_files_only = True
model_info_call_error = e # save error to reraise it if model is not cached locally
if not local_files_only: if not local_files_only:
config_file = hf_hub_download( config_file = hf_hub_download(
...@@ -1389,20 +1391,34 @@ class DiffusionPipeline(ConfigMixin): ...@@ -1389,20 +1391,34 @@ class DiffusionPipeline(ConfigMixin):
user_agent["custom_pipeline"] = custom_pipeline user_agent["custom_pipeline"] = custom_pipeline
# download all allow_patterns - ignore_patterns # download all allow_patterns - ignore_patterns
cached_folder = snapshot_download( try:
pretrained_model_name, return snapshot_download(
cache_dir=cache_dir, pretrained_model_name,
resume_download=resume_download, cache_dir=cache_dir,
proxies=proxies, resume_download=resume_download,
local_files_only=local_files_only, proxies=proxies,
use_auth_token=use_auth_token, local_files_only=local_files_only,
revision=revision, use_auth_token=use_auth_token,
allow_patterns=allow_patterns, revision=revision,
ignore_patterns=ignore_patterns, allow_patterns=allow_patterns,
user_agent=user_agent, ignore_patterns=ignore_patterns,
) user_agent=user_agent,
)
return cached_folder except FileNotFoundError:
# Means we tried to load pipeline with `local_files_only=True` but the files have not been found in local cache.
# This can happen in two cases:
# 1. If the user passed `local_files_only=True` => we raise the error directly
# 2. If we forced `local_files_only=True` when `model_info` failed => we raise the initial error
if model_info_call_error is None:
# 1. user passed `local_files_only=True`
raise
else:
# 2. we forced `local_files_only=True` when `model_info` failed
raise EnvironmentError(
f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occured"
" while trying to fetch metadata from the Hub. Please check out the root cause in the stacktrace"
" above."
) from model_info_call_error
@staticmethod @staticmethod
def _get_signature_keys(obj): def _get_signature_keys(obj):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment