Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
0f09b01a
Unverified
Commit
0f09b01a
authored
Jul 17, 2024
by
Sayak Paul
Committed by
GitHub
Jul 17, 2024
Browse files
[Core] fix: shard loading and saving when variant is provided. (#8869)
fix: shard loading and saving when variant is provided.
parent
f6cfe0a1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
1 deletion
+40
-1
src/diffusers/utils/hub_utils.py
src/diffusers/utils/hub_utils.py
+2
-1
tests/models/test_modeling_common.py
tests/models/test_modeling_common.py
+38
-0
No files found.
src/diffusers/utils/hub_utils.py
View file @
0f09b01a
...
...
@@ -271,7 +271,8 @@ if cache_version < 1:
def
_add_variant
(
weights_name
:
str
,
variant
:
Optional
[
str
]
=
None
)
->
str
:
if
variant
is
not
None
:
splits
=
weights_name
.
split
(
"."
)
splits
=
splits
[:
-
1
]
+
[
variant
]
+
splits
[
-
1
:]
split_index
=
-
2
if
weights_name
.
endswith
(
".index.json"
)
else
-
1
splits
=
splits
[:
-
split_index
]
+
[
variant
]
+
splits
[
-
split_index
:]
weights_name
=
"."
.
join
(
splits
)
return
weights_name
...
...
tests/models/test_modeling_common.py
View file @
0f09b01a
...
...
@@ -40,6 +40,7 @@ from diffusers.models.attention_processor import (
)
from
diffusers.training_utils
import
EMAModel
from
diffusers.utils
import
SAFE_WEIGHTS_INDEX_NAME
,
is_torch_npu_available
,
is_xformers_available
,
logging
from
diffusers.utils.hub_utils
import
_add_variant
from
diffusers.utils.testing_utils
import
(
CaptureLogger
,
get_python_version
,
...
...
@@ -915,6 +916,43 @@ class ModelTesterMixin:
self
.
assertTrue
(
torch
.
allclose
(
base_output
[
0
],
new_output
[
0
],
atol
=
1e-5
))
@
require_torch_gpu
def
test_sharded_checkpoints_with_variant
(
self
):
torch
.
manual_seed
(
0
)
config
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
config
).
eval
()
model
=
model
.
to
(
torch_device
)
base_output
=
model
(
**
inputs_dict
)
model_size
=
compute_module_sizes
(
model
)[
""
]
max_shard_size
=
int
((
model_size
*
0.75
)
/
(
2
**
10
))
# Convert to KB as these test models are small.
variant
=
"fp16"
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
# It doesn't matter if the actual model is in fp16 or not. Just adding the variant and
# testing if loading works with the variant when the checkpoint is sharded should be
# enough.
model
.
cpu
().
save_pretrained
(
tmp_dir
,
max_shard_size
=
f
"
{
max_shard_size
}
KB"
,
variant
=
variant
)
index_filename
=
_add_variant
(
SAFE_WEIGHTS_INDEX_NAME
,
variant
)
self
.
assertTrue
(
os
.
path
.
exists
(
os
.
path
.
join
(
tmp_dir
,
index_filename
)))
# Now check if the right number of shards exists. First, let's get the number of shards.
# Since this number can be dependent on the model being tested, it's important that we calculate it
# instead of hardcoding it.
expected_num_shards
=
caculate_expected_num_shards
(
os
.
path
.
join
(
tmp_dir
,
index_filename
))
actual_num_shards
=
len
([
file
for
file
in
os
.
listdir
(
tmp_dir
)
if
file
.
endswith
(
".safetensors"
)])
self
.
assertTrue
(
actual_num_shards
==
expected_num_shards
)
new_model
=
self
.
model_class
.
from_pretrained
(
tmp_dir
,
variant
=
variant
).
eval
()
new_model
=
new_model
.
to
(
torch_device
)
torch
.
manual_seed
(
0
)
if
"generator"
in
inputs_dict
:
_
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
new_output
=
new_model
(
**
inputs_dict
)
self
.
assertTrue
(
torch
.
allclose
(
base_output
[
0
],
new_output
[
0
],
atol
=
1e-5
))
@
require_torch_gpu
def
test_sharded_checkpoints_device_map
(
self
):
config
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment