Unverified Commit 3f7edc5f authored by Duong A. Nguyen's avatar Duong A. Nguyen Committed by GitHub
Browse files

Fix layer names convert LDM script (#1206)

fix script convert LDM
parent cd77a036
......@@ -112,9 +112,9 @@ def assign_to_checkpoint(
continue
# Global renaming happens here
new_path = new_path.replace("middle_block.0", "mid.resnets.0")
new_path = new_path.replace("middle_block.1", "mid.attentions.0")
new_path = new_path.replace("middle_block.2", "mid.resnets.1")
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
if additional_replacements is not None:
for replacement in additional_replacements:
......@@ -175,15 +175,16 @@ def convert_ldm_checkpoint(checkpoint, config):
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
if f"input_blocks.{i}.0.op.weight" in checkpoint:
new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[
f"input_blocks.{i}.0.op.weight"
]
new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[
f"input_blocks.{i}.0.op.bias"
]
continue
paths = renew_resnet_paths(resnets)
meta_path = {"old": f"input_blocks.{i}.0", "new": f"downsample_blocks.{block_id}.resnets.{layer_in_block_id}"}
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
resnet_op = {"old": "resnets.2.op", "new": "downsamplers.0.op"}
assign_to_checkpoint(
paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config
......@@ -193,18 +194,18 @@ def convert_ldm_checkpoint(checkpoint, config):
paths = renew_attention_paths(attentions)
meta_path = {
"old": f"input_blocks.{i}.1",
"new": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}",
"new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}",
}
to_split = {
f"input_blocks.{i}.1.qkv.bias": {
"key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
"query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
"value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
"key": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
"query": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
"value": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
},
f"input_blocks.{i}.1.qkv.weight": {
"key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
"query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
"value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
"key": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
"query": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
"value": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
},
}
assign_to_checkpoint(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment