Commit 3f1e9592 authored by Patrick von Platen

Fix conversion script

parent 87060e6a
#!/usr/bin/env python3
import json
import os

import torch
from huggingface_hub import hf_hub_download

from diffusers import UNetUnconditionalModel
from scripts.convert_ldm_original_checkpoint_to_diffusers import convert_ldm_checkpoint

model_id = "fusing/latent-diffusion-celeba-256"
subfolder = "unet"

# model_id = "fusing/unet-ldm-dummy"
# subfolder = None

checkpoint = "diffusion_model.pt"
config = "config.json"

if subfolder is not None:
    checkpoint = os.path.join(subfolder, checkpoint)
    config = os.path.join(subfolder, config)

# Download the original LDM checkpoint and its config, then convert it with the
# conversion script under test.
original_checkpoint = torch.load(hf_hub_download(model_id, checkpoint))
config_path = hf_hub_download(model_id, config)

with open(config_path) as f:
    config = json.load(f)

checkpoint = convert_ldm_checkpoint(original_checkpoint, config)


def current_codebase_conversion():
    # Reference state dict: the weights as loaded by the current codebase.
    model = UNetUnconditionalModel.from_pretrained(model_id, subfolder=subfolder, ldm=True)
    model.eval()

    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    # Run a forward pass as a smoke test before dumping the weights.
    noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
    time_step = torch.tensor([10] * noise.shape[0])

    with torch.no_grad():
        output = model(noise, time_step)

    return model.state_dict()


currently_converted_checkpoint = current_codebase_conversion()
torch.save(currently_converted_checkpoint, 'currently_converted_checkpoint.pt')


def diff_between_checkpoints(ch_0, ch_1):
    # Compare two state dicts: first the key sets, then the tensor values.
    all_layers_included = False

    if not set(ch_0.keys()) == set(ch_1.keys()):
        print(f"Contained in ch_0 and not in ch_1 (Total: {len(set(ch_0.keys()) - set(ch_1.keys()))})")
        for key in sorted(set(ch_0.keys()) - set(ch_1.keys())):
            print(f"\t{key}")
        print(f"Contained in ch_1 and not in ch_0 (Total: {len(set(ch_1.keys()) - set(ch_0.keys()))})")
        for key in sorted(set(ch_1.keys()) - set(ch_0.keys())):
            print(f"\t{key}")
    else:
        print("Keys are the same between the two checkpoints")
        all_layers_included = True

    keys = ch_0.keys()
    non_equal_keys = []

    if all_layers_included:
        for key in keys:
            try:
                if not torch.allclose(ch_0[key].cpu(), ch_1[key].cpu()):
                    non_equal_keys.append(f'{key}. Diff: {torch.max(torch.abs(ch_0[key].cpu() - ch_1[key].cpu()))}')
            except RuntimeError as e:
                # Tensors with mismatching shapes cannot be compared element-wise.
                print(e)
                non_equal_keys.append(f'{key}. Diff in shape: {ch_0[key].size()} vs {ch_1[key].size()}')

    if len(non_equal_keys):
        non_equal_keys = '\n\t'.join(non_equal_keys)
        print(f"These keys do not satisfy equivalence requirement:\n\t{non_equal_keys}")
    else:
        print("All keys are equal across checkpoints.")


diff_between_checkpoints(currently_converted_checkpoint, checkpoint)
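For reference, a toy invocation of the comparison helper defined above (hypothetical two-key state dicts, reusing the torch import from the script; the expected output is quoted from the print statements above):

# Toy check of diff_between_checkpoints with identical state dicts.
a = {"layer.weight": torch.ones(2, 2), "layer.bias": torch.zeros(2)}
b = {"layer.weight": torch.ones(2, 2), "layer.bias": torch.zeros(2)}
diff_between_checkpoints(a, b)
# Keys are the same between the two checkpoints
# All keys are equal across checkpoints.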
scripts/convert_ldm_original_checkpoint_to_diffusers.py

@@ -72,7 +72,7 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
     return mapping


-def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None):
+def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
     """
     This does the final conversion step: take locally converted weights and apply a global renaming
     to them. It splits attention layers, and takes into account additional replacements
@@ -85,11 +85,19 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
     # Splits the attention layers into three variables.
     if attention_paths_to_split is not None:
         for path, path_map in attention_paths_to_split.items():
-            query, key, value = torch.split(old_checkpoint[path], int(old_checkpoint[path].shape[0] / 3))
-
-            checkpoint[path_map['query']] = query
-            checkpoint[path_map['key']] = key
-            checkpoint[path_map['value']] = value
+            old_tensor = old_checkpoint[path]
+            channels = old_tensor.shape[0] // 3
+
+            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
+
+            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
+
+            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
+            query, key, value = old_tensor.split(channels // num_heads, dim=1)
+
+            checkpoint[path_map['query']] = query.reshape(target_shape)
+            checkpoint[path_map['key']] = key.reshape(target_shape)
+            checkpoint[path_map['value']] = value.reshape(target_shape)

     for path in paths:
         new_path = path['new']
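The hunk above is the core of the fix: the fused qkv projection is now grouped per attention head (using num_head_channels from the config) before being split into query, key and value, and each slice is reshaped to the 2-D shape a linear projection expects, instead of the old contiguous three-way torch.split. A minimal sketch of the same shape arithmetic with made-up sizes (the tensor and num_head_channels = 4 are illustrative, not taken from a real checkpoint):

# Sketch of the head-aware qkv split; made-up sizes.
import torch

num_head_channels = 4                                        # stands in for config["num_head_channels"]
old_tensor = torch.randn(24, 8, 1)                           # fused qkv conv weight, shape (3*C, C, 1)

channels = old_tensor.shape[0] // 3                          # C = 8
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
num_heads = old_tensor.shape[0] // num_head_channels // 3    # 2 heads

# Group per head, then slice each head's block into its q / k / v part.
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)

print(query.reshape(target_shape).shape)                     # torch.Size([8, 8]) -> a linear weight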
@@ -107,7 +115,11 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
             for replacement in additional_replacements:
                 new_path = new_path.replace(replacement['old'], replacement['new'])

-        checkpoint[new_path] = old_checkpoint[path['old']]
+        # proj_attn.weight has to be converted from conv 1D to linear
+        if "proj_attn.weight" in new_path:
+            checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0]
+        else:
+            checkpoint[new_path] = old_checkpoint[path['old']]


 def convert_ldm_checkpoint(checkpoint, config):
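The proj_attn.weight branch above works because a kernel-size-1 Conv1d weight has shape (out, in, 1), i.e. the same parameters as a Linear weight of shape (out, in) plus a trailing kernel axis. A quick sanity check with made-up sizes:

# Kernel-size-1 Conv1d vs. Linear: same computation once the trailing axis is dropped.
import torch

conv = torch.nn.Conv1d(8, 8, kernel_size=1)
linear = torch.nn.Linear(8, 8)

with torch.no_grad():
    linear.weight.copy_(conv.weight[:, :, 0])   # (out, in, 1) -> (out, in)
    linear.bias.copy_(conv.bias)

x = torch.randn(2, 8)
out_conv = conv(x.unsqueeze(-1)).squeeze(-1)    # Conv1d expects (batch, channels, length)
print(torch.allclose(out_conv, linear(x), atol=1e-6))   # True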
@@ -155,7 +167,7 @@ def convert_ldm_checkpoint(checkpoint, config):
             paths = renew_resnet_paths(resnets)
             meta_path = {'old': f'input_blocks.{i}.0', 'new': f'downsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
             resnet_op = {'old': 'resnets.2.op', 'new': 'downsamplers.0.op'}
-            assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op])
+            assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config)

         if len(attentions):
             paths = renew_attention_paths(attentions)
@@ -177,19 +189,19 @@ def convert_ldm_checkpoint(checkpoint, config):
                 new_checkpoint,
                 checkpoint,
                 additional_replacements=[meta_path],
-                attention_paths_to_split=to_split
+                attention_paths_to_split=to_split,
+                config=config
             )

     resnet_0 = middle_blocks[0]
     attentions = middle_blocks[1]
     resnet_1 = middle_blocks[2]

     resnet_0_paths = renew_resnet_paths(resnet_0)
-    assign_to_checkpoint(resnet_0_paths, new_checkpoint, checkpoint)
+    assign_to_checkpoint(resnet_0_paths, new_checkpoint, checkpoint, config=config)

     resnet_1_paths = renew_resnet_paths(resnet_1)
-    assign_to_checkpoint(resnet_1_paths, new_checkpoint, checkpoint)
+    assign_to_checkpoint(resnet_1_paths, new_checkpoint, checkpoint, config=config)

     attentions_paths = renew_attention_paths(attentions)
     to_split = {
@@ -204,7 +216,7 @@ def convert_ldm_checkpoint(checkpoint, config):
             'value': 'mid.attentions.0.value.weight',
         },
     }
-    assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split)
+    assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config)

     for i in range(num_output_blocks):
         block_id = i // (config['num_res_blocks'] + 1)
@@ -227,7 +239,7 @@ def convert_ldm_checkpoint(checkpoint, config):
             paths = renew_resnet_paths(resnets)

             meta_path = {'old': f'output_blocks.{i}.0', 'new': f'upsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
-            assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path])
+            assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)

             if ['conv.weight', 'conv.bias'] in output_block_list.values():
                 index = list(output_block_list.values()).index(['conv.weight', 'conv.bias'])
@@ -238,7 +250,6 @@ def convert_ldm_checkpoint(checkpoint, config):
             if len(attentions) == 2:
                 attentions = []

-
             if len(attentions):
                 paths = renew_attention_paths(attentions)
                 meta_path = {
@@ -262,7 +273,8 @@ def convert_ldm_checkpoint(checkpoint, config):
                     new_checkpoint,
                     checkpoint,
                     additional_replacements=[meta_path],
-                    attention_paths_to_split=to_split if any('qkv' in key for key in attentions) else None
+                    attention_paths_to_split=to_split if any('qkv' in key for key in attentions) else None,
+                    config=config,
                 )
         else:
             resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
@@ -296,7 +308,6 @@ if __name__ == "__main__":
     args = parser.parse_args()

-
     checkpoint = torch.load(args.checkpoint_path)

     with open(args.config_file) as f:
@@ -304,6 +315,3 @@ if __name__ == "__main__":
     converted_checkpoint = convert_ldm_checkpoint(checkpoint, config)

     torch.save(checkpoint, args.dump_path)
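For completeness: the entry point above reads args.checkpoint_path, args.config_file and args.dump_path. Assuming the parser defines flags with those exact names (the argument definitions are not part of this diff), a run of the converter would look roughly like `python scripts/convert_ldm_original_checkpoint_to_diffusers.py --checkpoint_path diffusion_model.pt --config_file config.json --dump_path converted_checkpoint.pt`, with the file names purely illustrative.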