Unverified Commit f6a5c359 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Community] Fix merger (#2006)

* [Community] Fix merger

* finish
parent 651c5adf
......@@ -32,6 +32,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
"""
def __init__(self):
self.register_to_config()
super().__init__()
def _compare_model_configs(self, dict0, dict1):
......@@ -167,6 +168,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
final_pipe = DiffusionPipeline.from_pretrained(
cached_folders[0], torch_dtype=torch_dtype, device_map=device_map
)
final_pipe.to(self.device)
checkpoint_path_2 = None
if len(cached_folders) > 2:
......@@ -202,9 +204,9 @@ class CheckpointMergerPipeline(DiffusionPipeline):
theta_0 = theta_0()
update_theta_0 = getattr(module, "load_state_dict")
theta_1 = torch.load(checkpoint_path_1)
theta_1 = torch.load(checkpoint_path_1, map_location="cpu")
theta_2 = torch.load(checkpoint_path_2) if checkpoint_path_2 else None
theta_2 = torch.load(checkpoint_path_2, map_location="cpu") if checkpoint_path_2 else None
if not theta_0.keys() == theta_1.keys():
print("SKIPPING ATTR ", attr, " DUE TO MISMATCH")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment