Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
e3ddbe25
Unverified
Commit
e3ddbe25
authored
Feb 16, 2023
by
Damian Stewart
Committed by
GitHub
Feb 16, 2023
Browse files
Fix 3-way merging with the checkpoint_merger community pipeline (#2355)
correctly locate 3rd file; also correct misleading docs
parent
46def726
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
9 deletions
+13
-9
examples/community/checkpoint_merger.py
examples/community/checkpoint_merger.py
+13
-9
No files found.
examples/community/checkpoint_merger.py
View file @
e3ddbe25
...
...
@@ -80,8 +80,8 @@ class CheckpointMergerPipeline(DiffusionPipeline):
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_diff
erence
" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_diff
erence
" is supported.
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_diff" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_diff" is supported.
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
...
...
@@ -206,14 +206,18 @@ class CheckpointMergerPipeline(DiffusionPipeline):
)
)
checkpoint_path_1
=
files
[
0
]
if
len
(
files
)
>
0
else
None
if
checkpoint_path_2
is
not
None
and
os
.
path
.
exists
(
checkpoint_path_2
):
files
=
list
(
(
*
glob
.
glob
(
os
.
path
.
join
(
checkpoint_path_2
,
"*.safetensors"
)),
*
glob
.
glob
(
os
.
path
.
join
(
checkpoint_path_2
,
"*.bin"
)),
if
len
(
cached_folders
)
<
3
:
checkpoint_path_2
=
None
else
:
checkpoint_path_2
=
os
.
path
.
join
(
cached_folders
[
2
],
attr
)
if
os
.
path
.
exists
(
checkpoint_path_2
):
files
=
list
(
(
*
glob
.
glob
(
os
.
path
.
join
(
checkpoint_path_2
,
"*.safetensors"
)),
*
glob
.
glob
(
os
.
path
.
join
(
checkpoint_path_2
,
"*.bin"
)),
)
)
)
checkpoint_path_2
=
files
[
0
]
if
len
(
files
)
>
0
else
None
checkpoint_path_2
=
files
[
0
]
if
len
(
files
)
>
0
else
None
# For an attr if both checkpoint_path_1 and 2 are None, ignore.
# If atleast one is present, deal with it according to interp method, of course only if the state_dict keys match.
if
checkpoint_path_1
is
None
and
checkpoint_path_2
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment