Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
0f09b01a
Unverified
Commit
0f09b01a
authored
Jul 17, 2024
by
Sayak Paul
Committed by
GitHub
Jul 17, 2024
Browse files
[Core] fix: shard loading and saving when variant is provided. (#8869)
fix: shard loading and saving when variant is provided.
parent
f6cfe0a1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
1 deletion
+40
-1
src/diffusers/utils/hub_utils.py
src/diffusers/utils/hub_utils.py
+2
-1
tests/models/test_modeling_common.py
tests/models/test_modeling_common.py
+38
-0
No files found.
src/diffusers/utils/hub_utils.py
View file @
0f09b01a
...
...
@@ -271,7 +271,8 @@ if cache_version < 1:
def
_add_variant
(
weights_name
:
str
,
variant
:
Optional
[
str
]
=
None
)
->
str
:
if
variant
is
not
None
:
splits
=
weights_name
.
split
(
"."
)
splits
=
splits
[:
-
1
]
+
[
variant
]
+
splits
[
-
1
:]
split_index
=
-
2
if
weights_name
.
endswith
(
".index.json"
)
else
-
1
splits
=
splits
[:
-
split_index
]
+
[
variant
]
+
splits
[
-
split_index
:]
weights_name
=
"."
.
join
(
splits
)
return
weights_name
...
...
tests/models/test_modeling_common.py
View file @
0f09b01a
...
...
@@ -40,6 +40,7 @@ from diffusers.models.attention_processor import (
)
from
diffusers.training_utils
import
EMAModel
from
diffusers.utils
import
SAFE_WEIGHTS_INDEX_NAME
,
is_torch_npu_available
,
is_xformers_available
,
logging
from
diffusers.utils.hub_utils
import
_add_variant
from
diffusers.utils.testing_utils
import
(
CaptureLogger
,
get_python_version
,
...
...
@@ -915,6 +916,43 @@ class ModelTesterMixin:
self
.
assertTrue
(
torch
.
allclose
(
base_output
[
0
],
new_output
[
0
],
atol
=
1e-5
))
@
require_torch_gpu
def
test_sharded_checkpoints_with_variant
(
self
):
torch
.
manual_seed
(
0
)
config
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
config
).
eval
()
model
=
model
.
to
(
torch_device
)
base_output
=
model
(
**
inputs_dict
)
model_size
=
compute_module_sizes
(
model
)[
""
]
max_shard_size
=
int
((
model_size
*
0.75
)
/
(
2
**
10
))
# Convert to KB as these test models are small.
variant
=
"fp16"
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
# It doesn't matter if the actual model is in fp16 or not. Just adding the variant and
# testing if loading works with the variant when the checkpoint is sharded should be
# enough.
model
.
cpu
().
save_pretrained
(
tmp_dir
,
max_shard_size
=
f
"
{
max_shard_size
}
KB"
,
variant
=
variant
)
index_filename
=
_add_variant
(
SAFE_WEIGHTS_INDEX_NAME
,
variant
)
self
.
assertTrue
(
os
.
path
.
exists
(
os
.
path
.
join
(
tmp_dir
,
index_filename
)))
# Now check if the right number of shards exists. First, let's get the number of shards.
# Since this number can be dependent on the model being tested, it's important that we calculate it
# instead of hardcoding it.
expected_num_shards
=
caculate_expected_num_shards
(
os
.
path
.
join
(
tmp_dir
,
index_filename
))
actual_num_shards
=
len
([
file
for
file
in
os
.
listdir
(
tmp_dir
)
if
file
.
endswith
(
".safetensors"
)])
self
.
assertTrue
(
actual_num_shards
==
expected_num_shards
)
new_model
=
self
.
model_class
.
from_pretrained
(
tmp_dir
,
variant
=
variant
).
eval
()
new_model
=
new_model
.
to
(
torch_device
)
torch
.
manual_seed
(
0
)
if
"generator"
in
inputs_dict
:
_
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
new_output
=
new_model
(
**
inputs_dict
)
self
.
assertTrue
(
torch
.
allclose
(
base_output
[
0
],
new_output
[
0
],
atol
=
1e-5
))
@
require_torch_gpu
def
test_sharded_checkpoints_device_map
(
self
):
config
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment