Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
95a78328
Unverified
Commit
95a78328
authored
Aug 01, 2024
by
YiYi Xu
Committed by
GitHub
Aug 01, 2024
Browse files
fix load sharded checkpoint from a subfolder (local path) (#8913)
fix Co-authored-by:
Sayak Paul
<
spsayakpaul@gmail.com
>
parent
c646fbc1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
60 additions
and
24 deletions
+60
-24
src/diffusers/utils/hub_utils.py
src/diffusers/utils/hub_utils.py
+26
-24
tests/models/unets/test_models_unet_2d_condition.py
tests/models/unets/test_models_unet_2d_condition.py
+34
-0
No files found.
src/diffusers/utils/hub_utils.py
View file @
95a78328
...
@@ -448,7 +448,7 @@ def _get_checkpoint_shard_files(
...
@@ -448,7 +448,7 @@ def _get_checkpoint_shard_files(
_check_if_shards_exist_locally
(
_check_if_shards_exist_locally
(
pretrained_model_name_or_path
,
subfolder
=
subfolder
,
original_shard_filenames
=
original_shard_filenames
pretrained_model_name_or_path
,
subfolder
=
subfolder
,
original_shard_filenames
=
original_shard_filenames
)
)
return
pretrained_model_name_or
_path
,
sharded_metadata
return
shards
_path
,
sharded_metadata
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
allow_patterns
=
original_shard_filenames
allow_patterns
=
original_shard_filenames
...
@@ -467,35 +467,37 @@ def _get_checkpoint_shard_files(
...
@@ -467,35 +467,37 @@ def _get_checkpoint_shard_files(
"required according to the checkpoint index."
"required according to the checkpoint index."
)
)
try
:
try
:
# Load from URL
# Load from URL
cached_folder
=
snapshot_download
(
cached_folder
=
snapshot_download
(
pretrained_model_name_or_path
,
pretrained_model_name_or_path
,
cache_dir
=
cache_dir
,
cache_dir
=
cache_dir
,
proxies
=
proxies
,
proxies
=
proxies
,
local_files_only
=
local_files_only
,
local_files_only
=
local_files_only
,
token
=
token
,
token
=
token
,
revision
=
revision
,
revision
=
revision
,
allow_patterns
=
allow_patterns
,
allow_patterns
=
allow_patterns
,
ignore_patterns
=
ignore_patterns
,
ignore_patterns
=
ignore_patterns
,
user_agent
=
user_agent
,
user_agent
=
user_agent
,
)
)
if
subfolder
is
not
None
:
if
subfolder
is
not
None
:
cached_folder
=
os
.
path
.
join
(
cached_folder
,
subfolder
)
cached_folder
=
os
.
path
.
join
(
cached_folder
,
subfolder
)
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
except
HTTPError
as
e
:
except
HTTPError
as
e
:
raise
EnvironmentError
(
raise
EnvironmentError
(
f
"We couldn't connect to '
{
HUGGINGFACE_CO_RESOLVE_ENDPOINT
}
' to load
{
pretrained_model_name_or_path
}
. You should try"
f
"We couldn't connect to '
{
HUGGINGFACE_CO_RESOLVE_ENDPOINT
}
' to load
{
pretrained_model_name_or_path
}
. You should try"
" again after checking your internet connection."
" again after checking your internet connection."
)
from
e
)
from
e
# If `local_files_only=True`, `cached_folder` may not contain all the shard files.
# If `local_files_only=True`, `cached_folder` may not contain all the shard files.
if
local_files_only
:
el
if
local_files_only
:
_check_if_shards_exist_locally
(
_check_if_shards_exist_locally
(
local_dir
=
cache_dir
,
subfolder
=
subfolder
,
original_shard_filenames
=
original_shard_filenames
local_dir
=
cache_dir
,
subfolder
=
subfolder
,
original_shard_filenames
=
original_shard_filenames
)
)
if
subfolder
is
not
None
:
cached_folder
=
os
.
path
.
join
(
cached_folder
,
subfolder
)
return
cached_folder
,
sharded_metadata
return
cached_folder
,
sharded_metadata
...
...
tests/models/unets/test_models_unet_2d_condition.py
View file @
95a78328
...
@@ -1068,6 +1068,17 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
...
@@ -1068,6 +1068,17 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert
loaded_model
assert
loaded_model
assert
new_output
.
sample
.
shape
==
(
4
,
4
,
16
,
16
)
assert
new_output
.
sample
.
shape
==
(
4
,
4
,
16
,
16
)
@
require_torch_gpu
def
test_load_sharded_checkpoint_from_hub_local_subfolder
(
self
):
_
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
ckpt_path
=
snapshot_download
(
"hf-internal-testing/unet2d-sharded-dummy-subfolder"
)
loaded_model
=
self
.
model_class
.
from_pretrained
(
ckpt_path
,
subfolder
=
"unet"
,
local_files_only
=
True
)
loaded_model
=
loaded_model
.
to
(
torch_device
)
new_output
=
loaded_model
(
**
inputs_dict
)
assert
loaded_model
assert
new_output
.
sample
.
shape
==
(
4
,
4
,
16
,
16
)
@
require_torch_gpu
@
require_torch_gpu
def
test_load_sharded_checkpoint_device_map_from_hub
(
self
):
def
test_load_sharded_checkpoint_device_map_from_hub
(
self
):
_
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
_
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
...
@@ -1077,6 +1088,17 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
...
@@ -1077,6 +1088,17 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert
loaded_model
assert
loaded_model
assert
new_output
.
sample
.
shape
==
(
4
,
4
,
16
,
16
)
assert
new_output
.
sample
.
shape
==
(
4
,
4
,
16
,
16
)
@
require_torch_gpu
def
test_load_sharded_checkpoint_device_map_from_hub_subfolder
(
self
):
_
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
loaded_model
=
self
.
model_class
.
from_pretrained
(
"hf-internal-testing/unet2d-sharded-dummy-subfolder"
,
subfolder
=
"unet"
,
device_map
=
"auto"
)
new_output
=
loaded_model
(
**
inputs_dict
)
assert
loaded_model
assert
new_output
.
sample
.
shape
==
(
4
,
4
,
16
,
16
)
@
require_torch_gpu
@
require_torch_gpu
def
test_load_sharded_checkpoint_device_map_from_hub_local
(
self
):
def
test_load_sharded_checkpoint_device_map_from_hub_local
(
self
):
_
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
_
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
...
@@ -1087,6 +1109,18 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
...
@@ -1087,6 +1109,18 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert
loaded_model
assert
loaded_model
assert
new_output
.
sample
.
shape
==
(
4
,
4
,
16
,
16
)
assert
new_output
.
sample
.
shape
==
(
4
,
4
,
16
,
16
)
@
require_torch_gpu
def
test_load_sharded_checkpoint_device_map_from_hub_local_subfolder
(
self
):
_
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
ckpt_path
=
snapshot_download
(
"hf-internal-testing/unet2d-sharded-dummy-subfolder"
)
loaded_model
=
self
.
model_class
.
from_pretrained
(
ckpt_path
,
local_files_only
=
True
,
subfolder
=
"unet"
,
device_map
=
"auto"
)
new_output
=
loaded_model
(
**
inputs_dict
)
assert
loaded_model
assert
new_output
.
sample
.
shape
==
(
4
,
4
,
16
,
16
)
@
require_peft_backend
@
require_peft_backend
def
test_lora
(
self
):
def
test_lora
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment