Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
f6a5c359
Unverified
Commit
f6a5c359
authored
Jan 16, 2023
by
Patrick von Platen
Committed by
GitHub
Jan 16, 2023
Browse files
[Community] Fix merger (#2006)
* [Community] Fix merger * finish
parent
651c5adf
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
2 deletions
+4
-2
examples/community/checkpoint_merger.py
examples/community/checkpoint_merger.py
+4
-2
No files found.
examples/community/checkpoint_merger.py
View file @
f6a5c359
...
...
@@ -32,6 +32,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
"""
def
__init__
(
self
):
self
.
register_to_config
()
super
().
__init__
()
def
_compare_model_configs
(
self
,
dict0
,
dict1
):
...
...
@@ -167,6 +168,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
final_pipe
=
DiffusionPipeline
.
from_pretrained
(
cached_folders
[
0
],
torch_dtype
=
torch_dtype
,
device_map
=
device_map
)
final_pipe
.
to
(
self
.
device
)
checkpoint_path_2
=
None
if
len
(
cached_folders
)
>
2
:
...
...
@@ -202,9 +204,9 @@ class CheckpointMergerPipeline(DiffusionPipeline):
theta_0
=
theta_0
()
update_theta_0
=
getattr
(
module
,
"load_state_dict"
)
theta_1
=
torch
.
load
(
checkpoint_path_1
)
theta_1
=
torch
.
load
(
checkpoint_path_1
,
map_location
=
"cpu"
)
theta_2
=
torch
.
load
(
checkpoint_path_2
)
if
checkpoint_path_2
else
None
theta_2
=
torch
.
load
(
checkpoint_path_2
,
map_location
=
"cpu"
)
if
checkpoint_path_2
else
None
if
not
theta_0
.
keys
()
==
theta_1
.
keys
():
print
(
"SKIPPING ATTR "
,
attr
,
" DUE TO MISMATCH"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment