Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
074798b2
Unverified
Commit
074798b2
authored
Dec 19, 2024
by
hlky
Committed by
GitHub
Dec 19, 2024
Browse files
Fix `local_files_only` for checkpoints with shards (#10294)
parent
3ee96695
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
38 deletions
+29
-38
src/diffusers/utils/hub_utils.py
src/diffusers/utils/hub_utils.py
+29
-38
No files found.
src/diffusers/utils/hub_utils.py
View file @
074798b2
...
...
@@ -455,48 +455,39 @@ def _get_checkpoint_shard_files(
allow_patterns
=
[
os
.
path
.
join
(
subfolder
,
p
)
for
p
in
allow_patterns
]
ignore_patterns
=
[
"*.json"
,
"*.md"
]
if
not
local_files_only
:
# `model_info` call must guarded with the above condition.
model_files_info
=
model_info
(
pretrained_model_name_or_path
,
revision
=
revision
,
token
=
token
)
for
shard_file
in
original_shard_filenames
:
shard_file_present
=
any
(
shard_file
in
k
.
rfilename
for
k
in
model_files_info
.
siblings
)
if
not
shard_file_present
:
raise
EnvironmentError
(
f
"
{
shards_path
}
does not appear to have a file named
{
shard_file
}
which is "
"required according to the checkpoint index."
)
try
:
# Load from URL
cached_folder
=
snapshot_download
(
pretrained_model_name_or_path
,
cache_dir
=
cache_dir
,
proxies
=
proxies
,
local_files_only
=
local_files_only
,
token
=
token
,
revision
=
revision
,
allow_patterns
=
allow_patterns
,
ignore_patterns
=
ignore_patterns
,
user_agent
=
user_agent
,
)
if
subfolder
is
not
None
:
cached_folder
=
os
.
path
.
join
(
cached_folder
,
subfolder
)
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
except
HTTPError
as
e
:
# `model_info` call must guarded with the above condition.
model_files_info
=
model_info
(
pretrained_model_name_or_path
,
revision
=
revision
,
token
=
token
)
for
shard_file
in
original_shard_filenames
:
shard_file_present
=
any
(
shard_file
in
k
.
rfilename
for
k
in
model_files_info
.
siblings
)
if
not
shard_file_present
:
raise
EnvironmentError
(
f
"
We couldn't connect to '
{
HUGGINGFACE_CO_RESOLVE_ENDPOINT
}
' to load
{
pretrained_model_name_or_path
}
. You should try
"
"
again after checking your internet connection
."
)
from
e
f
"
{
shards_path
}
does not appear to have a file named
{
shard_file
}
which is
"
"
required according to the checkpoint index
."
)
# If `local_files_only=True`, `cached_folder` may not contain all the shard files.
elif
local_files_only
:
_check_if_shards_exist_locally
(
local_dir
=
cache_dir
,
subfolder
=
subfolder
,
original_shard_filenames
=
original_shard_filenames
try
:
# Load from URL
cached_folder
=
snapshot_download
(
pretrained_model_name_or_path
,
cache_dir
=
cache_dir
,
proxies
=
proxies
,
local_files_only
=
local_files_only
,
token
=
token
,
revision
=
revision
,
allow_patterns
=
allow_patterns
,
ignore_patterns
=
ignore_patterns
,
user_agent
=
user_agent
,
)
if
subfolder
is
not
None
:
cached_folder
=
os
.
path
.
join
(
cache_dir
,
subfolder
)
cached_folder
=
os
.
path
.
join
(
cached_folder
,
subfolder
)
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
except
HTTPError
as
e
:
raise
EnvironmentError
(
f
"We couldn't connect to '
{
HUGGINGFACE_CO_RESOLVE_ENDPOINT
}
' to load
{
pretrained_model_name_or_path
}
. You should try"
" again after checking your internet connection."
)
from
e
return
cached_folder
,
sharded_metadata
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment