Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
f6a5c359
Unverified
Commit
f6a5c359
authored
Jan 16, 2023
by
Patrick von Platen
Committed by
GitHub
Jan 16, 2023
Browse files
[Community] Fix merger (#2006)
* [Community] Fix merger * finish
parent
651c5adf
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
2 deletions
+4
-2
examples/community/checkpoint_merger.py
examples/community/checkpoint_merger.py
+4
-2
No files found.
examples/community/checkpoint_merger.py
View file @
f6a5c359
...
@@ -32,6 +32,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
...
@@ -32,6 +32,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
register_to_config
()
super
().
__init__
()
super
().
__init__
()
def
_compare_model_configs
(
self
,
dict0
,
dict1
):
def
_compare_model_configs
(
self
,
dict0
,
dict1
):
...
@@ -167,6 +168,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
...
@@ -167,6 +168,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
final_pipe
=
DiffusionPipeline
.
from_pretrained
(
final_pipe
=
DiffusionPipeline
.
from_pretrained
(
cached_folders
[
0
],
torch_dtype
=
torch_dtype
,
device_map
=
device_map
cached_folders
[
0
],
torch_dtype
=
torch_dtype
,
device_map
=
device_map
)
)
final_pipe
.
to
(
self
.
device
)
checkpoint_path_2
=
None
checkpoint_path_2
=
None
if
len
(
cached_folders
)
>
2
:
if
len
(
cached_folders
)
>
2
:
...
@@ -202,9 +204,9 @@ class CheckpointMergerPipeline(DiffusionPipeline):
...
@@ -202,9 +204,9 @@ class CheckpointMergerPipeline(DiffusionPipeline):
theta_0
=
theta_0
()
theta_0
=
theta_0
()
update_theta_0
=
getattr
(
module
,
"load_state_dict"
)
update_theta_0
=
getattr
(
module
,
"load_state_dict"
)
theta_1
=
torch
.
load
(
checkpoint_path_1
)
theta_1
=
torch
.
load
(
checkpoint_path_1
,
map_location
=
"cpu"
)
theta_2
=
torch
.
load
(
checkpoint_path_2
)
if
checkpoint_path_2
else
None
theta_2
=
torch
.
load
(
checkpoint_path_2
,
map_location
=
"cpu"
)
if
checkpoint_path_2
else
None
if
not
theta_0
.
keys
()
==
theta_1
.
keys
():
if
not
theta_0
.
keys
()
==
theta_1
.
keys
():
print
(
"SKIPPING ATTR "
,
attr
,
" DUE TO MISMATCH"
)
print
(
"SKIPPING ATTR "
,
attr
,
" DUE TO MISMATCH"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment