Unverified Commit f6a5c359 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Community] Fix merger (#2006)

* [Community] Fix merger

* finish
parent 651c5adf
...@@ -32,6 +32,7 @@ class CheckpointMergerPipeline(DiffusionPipeline): ...@@ -32,6 +32,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
""" """
def __init__(self): def __init__(self):
self.register_to_config()
super().__init__() super().__init__()
def _compare_model_configs(self, dict0, dict1): def _compare_model_configs(self, dict0, dict1):
...@@ -167,6 +168,7 @@ class CheckpointMergerPipeline(DiffusionPipeline): ...@@ -167,6 +168,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
final_pipe = DiffusionPipeline.from_pretrained( final_pipe = DiffusionPipeline.from_pretrained(
cached_folders[0], torch_dtype=torch_dtype, device_map=device_map cached_folders[0], torch_dtype=torch_dtype, device_map=device_map
) )
final_pipe.to(self.device)
checkpoint_path_2 = None checkpoint_path_2 = None
if len(cached_folders) > 2: if len(cached_folders) > 2:
...@@ -202,9 +204,9 @@ class CheckpointMergerPipeline(DiffusionPipeline): ...@@ -202,9 +204,9 @@ class CheckpointMergerPipeline(DiffusionPipeline):
theta_0 = theta_0() theta_0 = theta_0()
update_theta_0 = getattr(module, "load_state_dict") update_theta_0 = getattr(module, "load_state_dict")
theta_1 = torch.load(checkpoint_path_1) theta_1 = torch.load(checkpoint_path_1, map_location="cpu")
theta_2 = torch.load(checkpoint_path_2) if checkpoint_path_2 else None theta_2 = torch.load(checkpoint_path_2, map_location="cpu") if checkpoint_path_2 else None
if not theta_0.keys() == theta_1.keys(): if not theta_0.keys() == theta_1.keys():
print("SKIPPING ATTR ", attr, " DUE TO MISMATCH") print("SKIPPING ATTR ", attr, " DUE TO MISMATCH")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment