Unverified Commit e3ddbe25 authored by Damian Stewart's avatar Damian Stewart Committed by GitHub
Browse files

Fix 3-way merging with the checkpoint_merger community pipeline (#2355)

correctly locate 3rd file; also correct misleading docs
parent 46def726
...@@ -80,8 +80,8 @@ class CheckpointMergerPipeline(DiffusionPipeline): ...@@ -80,8 +80,8 @@ class CheckpointMergerPipeline(DiffusionPipeline):
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2 would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None. interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_diff" and None.
Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_diff" is supported.
force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False. force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
...@@ -206,7 +206,11 @@ class CheckpointMergerPipeline(DiffusionPipeline): ...@@ -206,7 +206,11 @@ class CheckpointMergerPipeline(DiffusionPipeline):
) )
) )
checkpoint_path_1 = files[0] if len(files) > 0 else None checkpoint_path_1 = files[0] if len(files) > 0 else None
if checkpoint_path_2 is not None and os.path.exists(checkpoint_path_2): if len(cached_folders) < 3:
checkpoint_path_2 = None
else:
checkpoint_path_2 = os.path.join(cached_folders[2], attr)
if os.path.exists(checkpoint_path_2):
files = list( files = list(
( (
*glob.glob(os.path.join(checkpoint_path_2, "*.safetensors")), *glob.glob(os.path.join(checkpoint_path_2, "*.safetensors")),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment