Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
95a78328
Unverified
Commit
95a78328
authored
Aug 01, 2024
by
YiYi Xu
Committed by
GitHub
Aug 01, 2024
Browse files
fix load sharded checkpoint from a subfolder (local path) (#8913)
fix Co-authored-by:
Sayak Paul
<
spsayakpaul@gmail.com
>
parent
c646fbc1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
60 additions
and
24 deletions
+60
-24
src/diffusers/utils/hub_utils.py
src/diffusers/utils/hub_utils.py
+26
-24
tests/models/unets/test_models_unet_2d_condition.py
tests/models/unets/test_models_unet_2d_condition.py
+34
-0
No files found.
src/diffusers/utils/hub_utils.py
View file @
95a78328
...
...
@@ -448,7 +448,7 @@ def _get_checkpoint_shard_files(
_check_if_shards_exist_locally
(
pretrained_model_name_or_path
,
subfolder
=
subfolder
,
original_shard_filenames
=
original_shard_filenames
)
return
pretrained_model_name_or
_path
,
sharded_metadata
return
shards
_path
,
sharded_metadata
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
allow_patterns
=
original_shard_filenames
...
...
@@ -467,35 +467,37 @@ def _get_checkpoint_shard_files(
"required according to the checkpoint index."
)
try
:
# Load from URL
cached_folder
=
snapshot_download
(
pretrained_model_name_or_path
,
cache_dir
=
cache_dir
,
proxies
=
proxies
,
local_files_only
=
local_files_only
,
token
=
token
,
revision
=
revision
,
allow_patterns
=
allow_patterns
,
ignore_patterns
=
ignore_patterns
,
user_agent
=
user_agent
,
)
if
subfolder
is
not
None
:
cached_folder
=
os
.
path
.
join
(
cached_folder
,
subfolder
)
try
:
# Load from URL
cached_folder
=
snapshot_download
(
pretrained_model_name_or_path
,
cache_dir
=
cache_dir
,
proxies
=
proxies
,
local_files_only
=
local_files_only
,
token
=
token
,
revision
=
revision
,
allow_patterns
=
allow_patterns
,
ignore_patterns
=
ignore_patterns
,
user_agent
=
user_agent
,
)
if
subfolder
is
not
None
:
cached_folder
=
os
.
path
.
join
(
cached_folder
,
subfolder
)
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
except
HTTPError
as
e
:
raise
EnvironmentError
(
f
"We couldn't connect to '
{
HUGGINGFACE_CO_RESOLVE_ENDPOINT
}
' to load
{
pretrained_model_name_or_path
}
. You should try"
" again after checking your internet connection."
)
from
e
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
except
HTTPError
as
e
:
raise
EnvironmentError
(
f
"We couldn't connect to '
{
HUGGINGFACE_CO_RESOLVE_ENDPOINT
}
' to load
{
pretrained_model_name_or_path
}
. You should try"
" again after checking your internet connection."
)
from
e
# If `local_files_only=True`, `cached_folder` may not contain all the shard files.
if
local_files_only
:
el
if
local_files_only
:
_check_if_shards_exist_locally
(
local_dir
=
cache_dir
,
subfolder
=
subfolder
,
original_shard_filenames
=
original_shard_filenames
)
if
subfolder
is
not
None
:
cached_folder
=
os
.
path
.
join
(
cached_folder
,
subfolder
)
return
cached_folder
,
sharded_metadata
...
...
tests/models/unets/test_models_unet_2d_condition.py
View file @
95a78328
...
...
@@ -1068,6 +1068,17 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert
loaded_model
assert
new_output
.
sample
.
shape
==
(
4
,
4
,
16
,
16
)
@
require_torch_gpu
def
test_load_sharded_checkpoint_from_hub_local_subfolder
(
self
):
_
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
ckpt_path
=
snapshot_download
(
"hf-internal-testing/unet2d-sharded-dummy-subfolder"
)
loaded_model
=
self
.
model_class
.
from_pretrained
(
ckpt_path
,
subfolder
=
"unet"
,
local_files_only
=
True
)
loaded_model
=
loaded_model
.
to
(
torch_device
)
new_output
=
loaded_model
(
**
inputs_dict
)
assert
loaded_model
assert
new_output
.
sample
.
shape
==
(
4
,
4
,
16
,
16
)
@
require_torch_gpu
def
test_load_sharded_checkpoint_device_map_from_hub
(
self
):
_
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
...
...
@@ -1077,6 +1088,17 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert
loaded_model
assert
new_output
.
sample
.
shape
==
(
4
,
4
,
16
,
16
)
@
require_torch_gpu
def
test_load_sharded_checkpoint_device_map_from_hub_subfolder
(
self
):
_
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
loaded_model
=
self
.
model_class
.
from_pretrained
(
"hf-internal-testing/unet2d-sharded-dummy-subfolder"
,
subfolder
=
"unet"
,
device_map
=
"auto"
)
new_output
=
loaded_model
(
**
inputs_dict
)
assert
loaded_model
assert
new_output
.
sample
.
shape
==
(
4
,
4
,
16
,
16
)
@
require_torch_gpu
def
test_load_sharded_checkpoint_device_map_from_hub_local
(
self
):
_
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
...
...
@@ -1087,6 +1109,18 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert
loaded_model
assert
new_output
.
sample
.
shape
==
(
4
,
4
,
16
,
16
)
@
require_torch_gpu
def
test_load_sharded_checkpoint_device_map_from_hub_local_subfolder
(
self
):
_
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
ckpt_path
=
snapshot_download
(
"hf-internal-testing/unet2d-sharded-dummy-subfolder"
)
loaded_model
=
self
.
model_class
.
from_pretrained
(
ckpt_path
,
local_files_only
=
True
,
subfolder
=
"unet"
,
device_map
=
"auto"
)
new_output
=
loaded_model
(
**
inputs_dict
)
assert
loaded_model
assert
new_output
.
sample
.
shape
==
(
4
,
4
,
16
,
16
)
@
require_peft_backend
def
test_lora
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment