Commit 3f0b44b3 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

improve ddpm conversion script

parent cb90fd69
from diffusers import UNetUnconditionalModel
from diffusers import UNetUnconditionalModel, DDPMScheduler, DDPMPipeline
import argparse
import json
import torch
......@@ -56,7 +56,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
if attention_paths_to_split is not None:
if config is None:
raise ValueError(f"Please specify the config if setting 'attention_paths_to_split' to 'True'.")
raise ValueError("Please specify the config if setting 'attention_paths_to_split' to 'True'.")
for path, path_map in attention_paths_to_split.items():
old_tensor = old_checkpoint[path]
......@@ -86,7 +86,6 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
for replacement in additional_replacements:
new_path = new_path.replace(replacement['old'], replacement['new'])
if 'attentions' in new_path:
checkpoint[new_path] = old_checkpoint[path['old']].squeeze()
else:
......@@ -97,7 +96,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
new_checkpoint = {}
new_checkpoint['time_embedding.linear_1.weight'] = checkpoint['temb.dense.0.weight']
......@@ -121,7 +119,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
for i in range(num_downsample_blocks):
block_id = (i - 1) // (config['num_res_blocks'] + 1)
layer_in_block_id = (i - 1) % (config['num_res_blocks'] + 1)
if any('downsample' in layer for layer in downsample_blocks[i]):
new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
......@@ -138,7 +135,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
paths = renew_resnet_paths(blocks[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint)
if any('attn' in layer for layer in downsample_blocks[i]):
num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in downsample_blocks[i] if 'attn' in layer})
attns = {layer_id: [key for key in downsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
......@@ -148,7 +144,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
paths = renew_attention_paths(attns[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]
......@@ -186,7 +181,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
paths = renew_resnet_paths(blocks[j])
assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
if any('attn' in layer for layer in upsample_blocks[i]):
num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'attn' in layer})
attns = {layer_id: [key for key in upsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
......@@ -220,12 +214,21 @@ if __name__ == "__main__":
"--dump_path", default=None, type=str, required=True, help="Path to the output model."
)
args = parser.parse_args()
checkpoint = torch.load(args.checkpoint_path)
with open(args.config_file) as f:
config = json.loads(f.read())
converted_checkpoint = convert_ddpm_checkpoint(args.checkpoint_path, args.config_file)
torch.save(converted_checkpoint, args.dump_path)
converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)
if "ddpm" in config:
del config["ddpm"]
model = UNetUnconditionalModel(**config)
model.load_state_dict(converted_checkpoint)
scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
pipe = DDPMPipeline(unet=model, scheduler=scheduler)
pipe.save_pretrained(args.dump_path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment