Unverified Commit 3f7edc5f authored by Duong A. Nguyen's avatar Duong A. Nguyen Committed by GitHub
Browse files

Fix layer names convert LDM script (#1206)

fix script convert LDM
parent cd77a036
...@@ -112,9 +112,9 @@ def assign_to_checkpoint( ...@@ -112,9 +112,9 @@ def assign_to_checkpoint(
continue continue
# Global renaming happens here # Global renaming happens here
new_path = new_path.replace("middle_block.0", "mid.resnets.0") new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
new_path = new_path.replace("middle_block.1", "mid.attentions.0") new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
new_path = new_path.replace("middle_block.2", "mid.resnets.1") new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
if additional_replacements is not None: if additional_replacements is not None:
for replacement in additional_replacements: for replacement in additional_replacements:
...@@ -175,15 +175,16 @@ def convert_ldm_checkpoint(checkpoint, config): ...@@ -175,15 +175,16 @@ def convert_ldm_checkpoint(checkpoint, config):
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
if f"input_blocks.{i}.0.op.weight" in checkpoint: if f"input_blocks.{i}.0.op.weight" in checkpoint:
new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[
f"input_blocks.{i}.0.op.weight" f"input_blocks.{i}.0.op.weight"
] ]
new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[
f"input_blocks.{i}.0.op.bias" f"input_blocks.{i}.0.op.bias"
] ]
continue
paths = renew_resnet_paths(resnets) paths = renew_resnet_paths(resnets)
meta_path = {"old": f"input_blocks.{i}.0", "new": f"downsample_blocks.{block_id}.resnets.{layer_in_block_id}"} meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
resnet_op = {"old": "resnets.2.op", "new": "downsamplers.0.op"} resnet_op = {"old": "resnets.2.op", "new": "downsamplers.0.op"}
assign_to_checkpoint( assign_to_checkpoint(
paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config
...@@ -193,18 +194,18 @@ def convert_ldm_checkpoint(checkpoint, config): ...@@ -193,18 +194,18 @@ def convert_ldm_checkpoint(checkpoint, config):
paths = renew_attention_paths(attentions) paths = renew_attention_paths(attentions)
meta_path = { meta_path = {
"old": f"input_blocks.{i}.1", "old": f"input_blocks.{i}.1",
"new": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}",
} }
to_split = { to_split = {
f"input_blocks.{i}.1.qkv.bias": { f"input_blocks.{i}.1.qkv.bias": {
"key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias", "key": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
"query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias", "query": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
"value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias", "value": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
}, },
f"input_blocks.{i}.1.qkv.weight": { f"input_blocks.{i}.1.qkv.weight": {
"key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight", "key": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
"query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight", "query": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
"value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight", "value": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
}, },
} }
assign_to_checkpoint( assign_to_checkpoint(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment