"vscode:/vscode.git/clone" did not exist on "cabad2f7a295cbfc0f8e0710ffd69a0c79bd3d8d"
Unverified commit 7533e3d7, authored by Lincoln Stein, committed by GitHub

[Feat] checkpoint_merger works on local models as well as ones that use safetensors (#2060)



* allow a local model directory to be used for merging

* moved checkpoint merge bugfix into main for testing

* possibly fix local variable "config_dict" referenced before assignment

* fix deprecation warning

* debugging...

* debugging

* allow safetensors

* safetensors try again

* fix syntax error

* further debugging

* fix logic error when checkpoint 2 is none

* more debugging...

* more debugging...

* more debugging...

* more debugging...

* debugging

* clean up status reporting

* skip the requires_safety_checker boolean

* moved checkpoint merge bugfix into main for testing

* possibly fix local variable "config_dict" referenced before assignment

* fix deprecation warning

* allow safetensors

* fix logic error when checkpoint 2 is none

* clean up status reporting

* undo hack to use private repo for community pipelines

* allow a local model directory to be used for merging

* allow safetensors

* clean up status reporting

* reformatted with black

* sort imported modules correctly

* Update examples/community/checkpoint_merger.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update examples/community/checkpoint_merger.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update examples/community/checkpoint_merger.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix import style error
Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 41833109
```diff
@@ -4,6 +4,12 @@ from typing import Dict, List, Union
 
 import torch
 
+from diffusers.utils import is_safetensors_available
+
+
+if is_safetensors_available():
+    import safetensors.torch
+
 from diffusers import DiffusionPipeline, __version__
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME
```
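Note on the guarded import above: `safetensors.torch` is only imported when `is_safetensors_available()` reports the package, so the module still loads on installs without safetensors. A standalone sketch of the same fallback pattern, without the diffusers helper (names are illustrative):

```python
import torch

try:
    import safetensors.torch

    HAS_SAFETENSORS = True
except ImportError:
    HAS_SAFETENSORS = False


def load_weights(path):
    # Prefer the pickle-free safetensors loader for .safetensors files when the
    # library is installed; fall back to torch.load on CPU for .bin checkpoints.
    if HAS_SAFETENSORS and path.endswith(".safetensors"):
        return safetensors.torch.load_file(path)
    return torch.load(path, map_location="cpu")
```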
```diff
@@ -93,7 +99,8 @@ class CheckpointMergerPipeline(DiffusionPipeline):
         alpha = kwargs.pop("alpha", 0.5)
         interp = kwargs.pop("interp", None)
 
-        print("Recieved list", pretrained_model_name_or_path_list)
+        print("Received list", pretrained_model_name_or_path_list)
+        print(f"Combining with alpha={alpha}, interpolation mode={interp}")
 
         checkpoint_count = len(pretrained_model_name_or_path_list)
         # Ignore result from model_index_json comparison of the two checkpoints
```
```diff
@@ -114,8 +121,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
         # Step 1: Load the model config and compare the checkpoints. We'll compare the model_index.json first while ignoring the keys starting with '_'
         config_dicts = []
         for pretrained_model_name_or_path in pretrained_model_name_or_path_list:
-            if not os.path.isdir(pretrained_model_name_or_path):
-                config_dict = DiffusionPipeline.get_config_dict(
+            config_dict = DiffusionPipeline.load_config(
                 pretrained_model_name_or_path,
                 cache_dir=cache_dir,
                 resume_download=resume_download,
```
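`DiffusionPipeline.get_config_dict` had been deprecated in favor of `DiffusionPipeline.load_config`, which accepts either a Hub model ID or a local pipeline directory, so the old `os.path.isdir` special case is no longer needed here. A rough sketch of what the loaded config provides (the model ID is illustrative):

```python
from diffusers import DiffusionPipeline

# Accepts a Hub model ID or a local directory containing model_index.json.
config = DiffusionPipeline.load_config("CompVis/stable-diffusion-v1-4")

# Keys starting with "_" (e.g. "_class_name", "_diffusers_version") are pipeline
# metadata; the merger ignores them when comparing checkpoints for compatibility.
comparable = {k: v for k, v in config.items() if not k.startswith("_")}
print(config.get("_class_name"), sorted(comparable))
```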
```diff
@@ -149,7 +155,10 @@ class CheckpointMergerPipeline(DiffusionPipeline):
             requested_pipeline_class = config_dict.get("_class_name")
             user_agent = {"diffusers": __version__, "pipeline_class": requested_pipeline_class}
 
-            cached_folder = snapshot_download(
+            cached_folder = (
+                pretrained_model_name_or_path
+                if os.path.isdir(pretrained_model_name_or_path)
+                else snapshot_download(
                 pretrained_model_name_or_path,
                 cache_dir=cache_dir,
                 resume_download=resume_download,
```
```diff
@@ -160,11 +169,12 @@ class CheckpointMergerPipeline(DiffusionPipeline):
                 allow_patterns=allow_patterns,
                 user_agent=user_agent,
             )
+            )
             print("Cached Folder", cached_folder)
             cached_folders.append(cached_folder)
 
         # Step 3:-
-        # Load the first checkpoint as a diffusion pipeline and modify it's module state_dict in place
+        # Load the first checkpoint as a diffusion pipeline and modify its module state_dict in place
         final_pipe = DiffusionPipeline.from_pretrained(
             cached_folders[0], torch_dtype=torch_dtype, device_map=device_map
         )
```
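The parenthesized conditional introduced above is the heart of the local-model support: a path that already exists as a directory is used as the cached folder directly, and only Hub IDs go through `snapshot_download`. The same logic as a standalone helper (the function name is hypothetical):

```python
import os
from typing import Optional

from huggingface_hub import snapshot_download


def resolve_model_folder(name_or_path: str, cache_dir: Optional[str] = None) -> str:
    """Return a usable local folder: the path itself if it already is a
    directory on disk, otherwise a downloaded snapshot of the Hub repo."""
    if os.path.isdir(name_or_path):
        return name_or_path
    return snapshot_download(name_or_path, cache_dir=cache_dir)
```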
```diff
@@ -188,38 +198,56 @@ class CheckpointMergerPipeline(DiffusionPipeline):
             if not attr.startswith("_"):
                 checkpoint_path_1 = os.path.join(cached_folders[1], attr)
                 if os.path.exists(checkpoint_path_1):
-                    files = glob.glob(os.path.join(checkpoint_path_1, "*.bin"))
+                    files = list(
+                        (
+                            *glob.glob(os.path.join(checkpoint_path_1, "*.safetensors")),
+                            *glob.glob(os.path.join(checkpoint_path_1, "*.bin")),
+                        )
+                    )
                     checkpoint_path_1 = files[0] if len(files) > 0 else None
             if checkpoint_path_2 is not None and os.path.exists(checkpoint_path_2):
-                files = glob.glob(os.path.join(checkpoint_path_2, "*.bin"))
+                files = list(
+                    (
+                        *glob.glob(os.path.join(checkpoint_path_2, "*.safetensors")),
+                        *glob.glob(os.path.join(checkpoint_path_2, "*.bin")),
+                    )
+                )
                 checkpoint_path_2 = files[0] if len(files) > 0 else None
             # For an attr if both checkpoint_path_1 and 2 are None, ignore.
             # If at least one is present, deal with it according to interp method, of course only if the state_dict keys match.
             if checkpoint_path_1 is None and checkpoint_path_2 is None:
-                print("SKIPPING ATTR ", attr)
+                print(f"Skipping {attr}: not present in 2nd or 3rd model")
                 continue
             try:
                 module = getattr(final_pipe, attr)
+                if isinstance(module, bool):  # ignore requires_safety_checker boolean
+                    continue
                 theta_0 = getattr(module, "state_dict")
                 theta_0 = theta_0()
                 update_theta_0 = getattr(module, "load_state_dict")
-                theta_1 = torch.load(checkpoint_path_1, map_location="cpu")
-                theta_2 = torch.load(checkpoint_path_2, map_location="cpu") if checkpoint_path_2 else None
+                theta_1 = (
+                    safetensors.torch.load_file(checkpoint_path_1)
+                    if (is_safetensors_available() and checkpoint_path_1.endswith(".safetensors"))
+                    else torch.load(checkpoint_path_1, map_location="cpu")
+                )
+                theta_2 = None
+                if checkpoint_path_2:
+                    theta_2 = (
+                        safetensors.torch.load_file(checkpoint_path_2)
+                        if (is_safetensors_available() and checkpoint_path_2.endswith(".safetensors"))
+                        else torch.load(checkpoint_path_2, map_location="cpu")
+                    )
                 if not theta_0.keys() == theta_1.keys():
-                    print("SKIPPING ATTR ", attr, " DUE TO MISMATCH")
+                    print(f"Skipping {attr}: key mismatch")
                     continue
                 if theta_2 and not theta_1.keys() == theta_2.keys():
-                    print("SKIPPING ATTR ", attr, " DUE TO MISMATCH")
-            except:
-                print("SKIPPING ATTR ", attr)
+                    print(f"Skipping {attr}: key mismatch")
+            except Exception as e:
+                print(f"Skipping {attr} due to an unexpected error: {str(e)}")
                 continue
-            print("Found dicts for")
-            print(attr)
-            print(checkpoint_path_1)
-            print(checkpoint_path_2)
+            print(f"MERGING {attr}")
 
             for key in theta_0.keys():
                 if theta_2:
```
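Once `theta_0`, `theta_1`, and optionally `theta_2` are loaded, the merge itself is the per-key loop shown in the trailing context lines: each key is passed through the selected `theta_func`. The interpolation functions are static methods defined elsewhere in this file; a sketch of the default weighted-sum mode, with toy state dicts standing in for real module weights:

```python
import torch


def weighted_sum(theta0, theta1, theta2, alpha):
    # Default interpolation: linear blend of the first two checkpoints;
    # theta2 is unused here but kept for a uniform theta_func signature.
    return ((1 - alpha) * theta0) + (alpha * theta1)


# Per-key merge, assuming the key-match checks above have passed.
theta_func = weighted_sum
alpha = 0.5
theta_0 = {"w": torch.ones(2)}   # state dict of the first model's module
theta_1 = {"w": torch.zeros(2)}  # matching state dict from the second model
theta_2 = None                   # only set for three-way merges

for key in theta_0.keys():
    if theta_2:
        theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key], alpha)
    else:
        theta_0[key] = theta_func(theta_0[key], theta_1[key], None, alpha)

print(theta_0["w"])  # tensor([0.5000, 0.5000])
```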
```diff
@@ -232,8 +260,6 @@ class CheckpointMergerPipeline(DiffusionPipeline):
             update_theta_0(theta_0)
             del theta_0
         print("Diffusion pipeline successfully updated with merged weights")
 
         return final_pipe
 
     @staticmethod
```
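For reference, usage of the community pipeline this commit extends looks roughly like the following; the second list entry can now be a local model directory, and checkpoints shipped only as `.safetensors` also merge. Model IDs and the local path are illustrative:

```python
from diffusers import DiffusionPipeline

# Load the merger itself as a community pipeline.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="checkpoint_merger",
)

# Mix a Hub ID with a local model directory; interp selects the
# interpolation mode (None falls back to a plain weighted sum).
merged_pipe = pipe.merge(
    ["CompVis/stable-diffusion-v1-4", "/path/to/local/model"],
    interp="sigmoid",
    alpha=0.4,
)
```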