Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
89793a97
Unverified
Commit
89793a97
authored
Aug 25, 2022
by
Anton Lozhkov
Committed by
GitHub
Aug 25, 2022
Browse files
Style the `scripts` directory (#250)
Style scripts
parent
365f7523
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
452 additions
and
315 deletions
+452
-315
Makefile
Makefile
+1
-1
scripts/change_naming_configs_and_checkpoints.py
scripts/change_naming_configs_and_checkpoints.py
+6
-5
scripts/conversion_ldm_uncond.py
scripts/conversion_ldm_uncond.py
+5
-5
scripts/convert_ddpm_original_checkpoint_to_diffusers.py
scripts/convert_ddpm_original_checkpoint_to_diffusers.py
+208
-136
scripts/convert_ldm_original_checkpoint_to_diffusers.py
scripts/convert_ldm_original_checkpoint_to_diffusers.py
+123
-97
scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py
scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py
+3
-1
scripts/generate_logits.py
scripts/generate_logits.py
+106
-70
No files found.
Makefile
View file @
89793a97
...
...
@@ -3,7 +3,7 @@
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export
PYTHONPATH
=
src
check_dirs
:=
examples
tes
ts src utils
check_dirs
:=
examples
scrip
ts src
tests
utils
modified_only_fixup
:
$(
eval
modified_py_files :
=
$(
shell
python utils/get_modified_files.py
$(check_dirs)
))
...
...
scripts/change_naming_configs_and_checkpoints.py
View file @
89793a97
...
...
@@ -15,12 +15,15 @@
""" Conversion script for the LDM checkpoints. """
import
argparse
import
os
import
json
import
os
import
torch
from
diffusers
import
UNet2DModel
,
UNet2DConditionModel
from
diffusers
import
UNet2DConditionModel
,
UNet2DModel
from
transformers.file_utils
import
has_file
do_only_config
=
False
do_only_weights
=
True
do_only_renaming
=
False
...
...
@@ -37,9 +40,7 @@ if __name__ == "__main__":
help
=
"The config json file corresponding to the architecture."
,
)
parser
.
add_argument
(
"--dump_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the output model."
)
parser
.
add_argument
(
"--dump_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the output model."
)
args
=
parser
.
parse_args
()
...
...
scripts/conversion_ldm_uncond.py
View file @
89793a97
import
argparse
import
OmegaConf
import
torch
from
diffusers
import
UNetLDMModel
,
VQModel
,
LDMPipeline
,
DDIMScheduler
import
OmegaConf
from
diffusers
import
DDIMScheduler
,
LDMPipeline
,
UNetLDMModel
,
VQModel
def
convert_ldm_original
(
checkpoint_path
,
config_path
,
output_path
):
config
=
OmegaConf
.
load
(
config_path
)
...
...
@@ -53,4 +54,3 @@ if __name__ == "__main__":
args
=
parser
.
parse_args
()
convert_ldm_original
(
args
.
checkpoint_path
,
args
.
config_path
,
args
.
output_path
)
scripts/convert_ddpm_original_checkpoint_to_diffusers.py
View file @
89793a97
from
diffusers
import
UNet2DModel
,
DDPMScheduler
,
DDPMPipeline
,
VQModel
,
AutoencoderKL
import
argparse
import
json
import
torch
from
diffusers
import
AutoencoderKL
,
DDPMPipeline
,
DDPMScheduler
,
UNet2DModel
,
VQModel
def
shave_segments
(
path
,
n_shave_prefix_segments
=
1
):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
"""
if
n_shave_prefix_segments
>=
0
:
return
'.'
.
join
(
path
.
split
(
'.'
)[
n_shave_prefix_segments
:])
return
"."
.
join
(
path
.
split
(
"."
)[
n_shave_prefix_segments
:])
else
:
return
'.'
.
join
(
path
.
split
(
'.'
)[:
n_shave_prefix_segments
])
return
"."
.
join
(
path
.
split
(
"."
)[:
n_shave_prefix_segments
])
def
renew_resnet_paths
(
old_list
,
n_shave_prefix_segments
=
0
):
mapping
=
[]
for
old_item
in
old_list
:
new_item
=
old_item
new_item
=
new_item
.
replace
(
'
block.
'
,
'
resnets.
'
)
new_item
=
new_item
.
replace
(
'
conv_shorcut
'
,
'
conv1
'
)
new_item
=
new_item
.
replace
(
'
nin_shortcut
'
,
'
conv_shortcut
'
)
new_item
=
new_item
.
replace
(
'
temb_proj
'
,
'
time_emb_proj
'
)
new_item
=
new_item
.
replace
(
"
block.
"
,
"
resnets.
"
)
new_item
=
new_item
.
replace
(
"
conv_shorcut
"
,
"
conv1
"
)
new_item
=
new_item
.
replace
(
"
nin_shortcut
"
,
"
conv_shortcut
"
)
new_item
=
new_item
.
replace
(
"
temb_proj
"
,
"
time_emb_proj
"
)
new_item
=
shave_segments
(
new_item
,
n_shave_prefix_segments
=
n_shave_prefix_segments
)
mapping
.
append
({
'
old
'
:
old_item
,
'
new
'
:
new_item
})
mapping
.
append
({
"
old
"
:
old_item
,
"
new
"
:
new_item
})
return
mapping
...
...
@@ -37,21 +39,23 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0, in_mid=False):
# In `model.mid`, the layer is called `attn`.
if
not
in_mid
:
new_item
=
new_item
.
replace
(
'
attn
'
,
'
attentions
'
)
new_item
=
new_item
.
replace
(
'
.k.
'
,
'
.key.
'
)
new_item
=
new_item
.
replace
(
'
.v.
'
,
'
.value.
'
)
new_item
=
new_item
.
replace
(
'
.q.
'
,
'
.query.
'
)
new_item
=
new_item
.
replace
(
"
attn
"
,
"
attentions
"
)
new_item
=
new_item
.
replace
(
"
.k.
"
,
"
.key.
"
)
new_item
=
new_item
.
replace
(
"
.v.
"
,
"
.value.
"
)
new_item
=
new_item
.
replace
(
"
.q.
"
,
"
.query.
"
)
new_item
=
new_item
.
replace
(
'
proj_out
'
,
'
proj_attn
'
)
new_item
=
new_item
.
replace
(
'
norm
'
,
'
group_norm
'
)
new_item
=
new_item
.
replace
(
"
proj_out
"
,
"
proj_attn
"
)
new_item
=
new_item
.
replace
(
"
norm
"
,
"
group_norm
"
)
new_item
=
shave_segments
(
new_item
,
n_shave_prefix_segments
=
n_shave_prefix_segments
)
mapping
.
append
({
'
old
'
:
old_item
,
'
new
'
:
new_item
})
mapping
.
append
({
"
old
"
:
old_item
,
"
new
"
:
new_item
})
return
mapping
def
assign_to_checkpoint
(
paths
,
checkpoint
,
old_checkpoint
,
attention_paths_to_split
=
None
,
additional_replacements
=
None
,
config
=
None
):
def
assign_to_checkpoint
(
paths
,
checkpoint
,
old_checkpoint
,
attention_paths_to_split
=
None
,
additional_replacements
=
None
,
config
=
None
):
assert
isinstance
(
paths
,
list
),
"Paths should be a list of dicts containing 'old' and 'new' keys."
if
attention_paths_to_split
is
not
None
:
...
...
@@ -69,27 +73,27 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
old_tensor
=
old_tensor
.
reshape
((
num_heads
,
3
*
channels
//
num_heads
)
+
old_tensor
.
shape
[
1
:])
query
,
key
,
value
=
old_tensor
.
split
(
channels
//
num_heads
,
dim
=
1
)
checkpoint
[
path_map
[
'
query
'
]]
=
query
.
reshape
(
target_shape
).
squeeze
()
checkpoint
[
path_map
[
'
key
'
]]
=
key
.
reshape
(
target_shape
).
squeeze
()
checkpoint
[
path_map
[
'
value
'
]]
=
value
.
reshape
(
target_shape
).
squeeze
()
checkpoint
[
path_map
[
"
query
"
]]
=
query
.
reshape
(
target_shape
).
squeeze
()
checkpoint
[
path_map
[
"
key
"
]]
=
key
.
reshape
(
target_shape
).
squeeze
()
checkpoint
[
path_map
[
"
value
"
]]
=
value
.
reshape
(
target_shape
).
squeeze
()
for
path
in
paths
:
new_path
=
path
[
'
new
'
]
new_path
=
path
[
"
new
"
]
if
attention_paths_to_split
is
not
None
and
new_path
in
attention_paths_to_split
:
continue
new_path
=
new_path
.
replace
(
'
down.
'
,
'
down_blocks.
'
)
new_path
=
new_path
.
replace
(
'
up.
'
,
'
up_blocks.
'
)
new_path
=
new_path
.
replace
(
"
down.
"
,
"
down_blocks.
"
)
new_path
=
new_path
.
replace
(
"
up.
"
,
"
up_blocks.
"
)
if
additional_replacements
is
not
None
:
for
replacement
in
additional_replacements
:
new_path
=
new_path
.
replace
(
replacement
[
'
old
'
],
replacement
[
'
new
'
])
new_path
=
new_path
.
replace
(
replacement
[
"
old
"
],
replacement
[
"
new
"
])
if
'
attentions
'
in
new_path
:
checkpoint
[
new_path
]
=
old_checkpoint
[
path
[
'
old
'
]].
squeeze
()
if
"
attentions
"
in
new_path
:
checkpoint
[
new_path
]
=
old_checkpoint
[
path
[
"
old
"
]].
squeeze
()
else
:
checkpoint
[
new_path
]
=
old_checkpoint
[
path
[
'
old
'
]]
checkpoint
[
new_path
]
=
old_checkpoint
[
path
[
"
old
"
]]
def
convert_ddpm_checkpoint
(
checkpoint
,
config
):
...
...
@@ -98,49 +102,63 @@ def convert_ddpm_checkpoint(checkpoint, config):
"""
new_checkpoint
=
{}
new_checkpoint
[
'
time_embedding.linear_1.weight
'
]
=
checkpoint
[
'
temb.dense.0.weight
'
]
new_checkpoint
[
'
time_embedding.linear_1.bias
'
]
=
checkpoint
[
'
temb.dense.0.bias
'
]
new_checkpoint
[
'
time_embedding.linear_2.weight
'
]
=
checkpoint
[
'
temb.dense.1.weight
'
]
new_checkpoint
[
'
time_embedding.linear_2.bias
'
]
=
checkpoint
[
'
temb.dense.1.bias
'
]
new_checkpoint
[
"
time_embedding.linear_1.weight
"
]
=
checkpoint
[
"
temb.dense.0.weight
"
]
new_checkpoint
[
"
time_embedding.linear_1.bias
"
]
=
checkpoint
[
"
temb.dense.0.bias
"
]
new_checkpoint
[
"
time_embedding.linear_2.weight
"
]
=
checkpoint
[
"
temb.dense.1.weight
"
]
new_checkpoint
[
"
time_embedding.linear_2.bias
"
]
=
checkpoint
[
"
temb.dense.1.bias
"
]
new_checkpoint
[
'
conv_norm_out.weight
'
]
=
checkpoint
[
'
norm_out.weight
'
]
new_checkpoint
[
'
conv_norm_out.bias
'
]
=
checkpoint
[
'
norm_out.bias
'
]
new_checkpoint
[
"
conv_norm_out.weight
"
]
=
checkpoint
[
"
norm_out.weight
"
]
new_checkpoint
[
"
conv_norm_out.bias
"
]
=
checkpoint
[
"
norm_out.bias
"
]
new_checkpoint
[
'
conv_in.weight
'
]
=
checkpoint
[
'
conv_in.weight
'
]
new_checkpoint
[
'
conv_in.bias
'
]
=
checkpoint
[
'
conv_in.bias
'
]
new_checkpoint
[
'
conv_out.weight
'
]
=
checkpoint
[
'
conv_out.weight
'
]
new_checkpoint
[
'
conv_out.bias
'
]
=
checkpoint
[
'
conv_out.bias
'
]
new_checkpoint
[
"
conv_in.weight
"
]
=
checkpoint
[
"
conv_in.weight
"
]
new_checkpoint
[
"
conv_in.bias
"
]
=
checkpoint
[
"
conv_in.bias
"
]
new_checkpoint
[
"
conv_out.weight
"
]
=
checkpoint
[
"
conv_out.weight
"
]
new_checkpoint
[
"
conv_out.bias
"
]
=
checkpoint
[
"
conv_out.bias
"
]
num_down_blocks
=
len
({
'.'
.
join
(
layer
.
split
(
'.'
)[:
2
])
for
layer
in
checkpoint
if
'down'
in
layer
})
down_blocks
=
{
layer_id
:
[
key
for
key
in
checkpoint
if
f
'down.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_down_blocks
)}
num_down_blocks
=
len
({
"."
.
join
(
layer
.
split
(
"."
)[:
2
])
for
layer
in
checkpoint
if
"down"
in
layer
})
down_blocks
=
{
layer_id
:
[
key
for
key
in
checkpoint
if
f
"down.
{
layer_id
}
"
in
key
]
for
layer_id
in
range
(
num_down_blocks
)
}
num_up_blocks
=
len
({
'.'
.
join
(
layer
.
split
(
'.'
)[:
2
])
for
layer
in
checkpoint
if
'
up
'
in
layer
})
up_blocks
=
{
layer_id
:
[
key
for
key
in
checkpoint
if
f
'
up.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_up_blocks
)}
num_up_blocks
=
len
({
"."
.
join
(
layer
.
split
(
"."
)[:
2
])
for
layer
in
checkpoint
if
"
up
"
in
layer
})
up_blocks
=
{
layer_id
:
[
key
for
key
in
checkpoint
if
f
"
up.
{
layer_id
}
"
in
key
]
for
layer_id
in
range
(
num_up_blocks
)}
for
i
in
range
(
num_down_blocks
):
block_id
=
(
i
-
1
)
//
(
config
[
'layers_per_block'
]
+
1
)
if
any
(
'downsample'
in
layer
for
layer
in
down_blocks
[
i
]):
new_checkpoint
[
f
'down_blocks.
{
i
}
.downsamplers.0.conv.weight'
]
=
checkpoint
[
f
'down.
{
i
}
.downsample.op.weight'
]
new_checkpoint
[
f
'down_blocks.
{
i
}
.downsamplers.0.conv.bias'
]
=
checkpoint
[
f
'down.
{
i
}
.downsample.op.bias'
]
# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
if
any
(
'block'
in
layer
for
layer
in
down_blocks
[
i
]):
num_blocks
=
len
({
'.'
.
join
(
shave_segments
(
layer
,
2
).
split
(
'.'
)[:
2
])
for
layer
in
down_blocks
[
i
]
if
'block'
in
layer
})
blocks
=
{
layer_id
:
[
key
for
key
in
down_blocks
[
i
]
if
f
'block.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_blocks
)}
block_id
=
(
i
-
1
)
//
(
config
[
"layers_per_block"
]
+
1
)
if
any
(
"downsample"
in
layer
for
layer
in
down_blocks
[
i
]):
new_checkpoint
[
f
"down_blocks.
{
i
}
.downsamplers.0.conv.weight"
]
=
checkpoint
[
f
"down.
{
i
}
.downsample.op.weight"
]
new_checkpoint
[
f
"down_blocks.
{
i
}
.downsamplers.0.conv.bias"
]
=
checkpoint
[
f
"down.
{
i
}
.downsample.op.bias"
]
# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
# new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
if
any
(
"block"
in
layer
for
layer
in
down_blocks
[
i
]):
num_blocks
=
len
(
{
"."
.
join
(
shave_segments
(
layer
,
2
).
split
(
"."
)[:
2
])
for
layer
in
down_blocks
[
i
]
if
"block"
in
layer
}
)
blocks
=
{
layer_id
:
[
key
for
key
in
down_blocks
[
i
]
if
f
"block.
{
layer_id
}
"
in
key
]
for
layer_id
in
range
(
num_blocks
)
}
if
num_blocks
>
0
:
for
j
in
range
(
config
[
'
layers_per_block
'
]):
for
j
in
range
(
config
[
"
layers_per_block
"
]):
paths
=
renew_resnet_paths
(
blocks
[
j
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
)
if
any
(
'attn'
in
layer
for
layer
in
down_blocks
[
i
]):
num_attn
=
len
({
'.'
.
join
(
shave_segments
(
layer
,
2
).
split
(
'.'
)[:
2
])
for
layer
in
down_blocks
[
i
]
if
'attn'
in
layer
})
attns
=
{
layer_id
:
[
key
for
key
in
down_blocks
[
i
]
if
f
'attn.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_blocks
)}
if
any
(
"attn"
in
layer
for
layer
in
down_blocks
[
i
]):
num_attn
=
len
(
{
"."
.
join
(
shave_segments
(
layer
,
2
).
split
(
"."
)[:
2
])
for
layer
in
down_blocks
[
i
]
if
"attn"
in
layer
}
)
attns
=
{
layer_id
:
[
key
for
key
in
down_blocks
[
i
]
if
f
"attn.
{
layer_id
}
"
in
key
]
for
layer_id
in
range
(
num_blocks
)
}
if
num_attn
>
0
:
for
j
in
range
(
config
[
'
layers_per_block
'
]):
for
j
in
range
(
config
[
"
layers_per_block
"
]):
paths
=
renew_attention_paths
(
attns
[
j
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
config
=
config
)
...
...
@@ -150,48 +168,67 @@ def convert_ddpm_checkpoint(checkpoint, config):
# Mid new 2
paths
=
renew_resnet_paths
(
mid_block_1_layers
)
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
{
'old'
:
'mid.'
,
'new'
:
'mid_new_2.'
},
{
'old'
:
'block_1'
,
'new'
:
'resnets.0'
}
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[{
"old"
:
"mid."
,
"new"
:
"mid_new_2."
},
{
"old"
:
"block_1"
,
"new"
:
"resnets.0"
}],
)
paths
=
renew_resnet_paths
(
mid_block_2_layers
)
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
{
'old'
:
'mid.'
,
'new'
:
'mid_new_2.'
},
{
'old'
:
'block_2'
,
'new'
:
'resnets.1'
}
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[{
"old"
:
"mid."
,
"new"
:
"mid_new_2."
},
{
"old"
:
"block_2"
,
"new"
:
"resnets.1"
}],
)
paths
=
renew_attention_paths
(
mid_attn_1_layers
,
in_mid
=
True
)
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
{
'old'
:
'mid.'
,
'new'
:
'mid_new_2.'
},
{
'old'
:
'attn_1'
,
'new'
:
'attentions.0'
}
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[{
"old"
:
"mid."
,
"new"
:
"mid_new_2."
},
{
"old"
:
"attn_1"
,
"new"
:
"attentions.0"
}],
)
for
i
in
range
(
num_up_blocks
):
block_id
=
num_up_blocks
-
1
-
i
if
any
(
'upsample'
in
layer
for
layer
in
up_blocks
[
i
]):
new_checkpoint
[
f
'up_blocks.
{
block_id
}
.upsamplers.0.conv.weight'
]
=
checkpoint
[
f
'up.
{
i
}
.upsample.conv.weight'
]
new_checkpoint
[
f
'up_blocks.
{
block_id
}
.upsamplers.0.conv.bias'
]
=
checkpoint
[
f
'up.
{
i
}
.upsample.conv.bias'
]
if
any
(
"upsample"
in
layer
for
layer
in
up_blocks
[
i
]):
new_checkpoint
[
f
"up_blocks.
{
block_id
}
.upsamplers.0.conv.weight"
]
=
checkpoint
[
f
"up.
{
i
}
.upsample.conv.weight"
]
new_checkpoint
[
f
"up_blocks.
{
block_id
}
.upsamplers.0.conv.bias"
]
=
checkpoint
[
f
"up.
{
i
}
.upsample.conv.bias"
]
if
any
(
'block'
in
layer
for
layer
in
up_blocks
[
i
]):
num_blocks
=
len
({
'.'
.
join
(
shave_segments
(
layer
,
2
).
split
(
'.'
)[:
2
])
for
layer
in
up_blocks
[
i
]
if
'block'
in
layer
})
blocks
=
{
layer_id
:
[
key
for
key
in
up_blocks
[
i
]
if
f
'block.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_blocks
)}
if
any
(
"block"
in
layer
for
layer
in
up_blocks
[
i
]):
num_blocks
=
len
(
{
"."
.
join
(
shave_segments
(
layer
,
2
).
split
(
"."
)[:
2
])
for
layer
in
up_blocks
[
i
]
if
"block"
in
layer
}
)
blocks
=
{
layer_id
:
[
key
for
key
in
up_blocks
[
i
]
if
f
"block.
{
layer_id
}
"
in
key
]
for
layer_id
in
range
(
num_blocks
)
}
if
num_blocks
>
0
:
for
j
in
range
(
config
[
'
layers_per_block
'
]
+
1
):
replace_indices
=
{
'
old
'
:
f
'
up_blocks.
{
i
}
'
,
'
new
'
:
f
'
up_blocks.
{
block_id
}
'
}
for
j
in
range
(
config
[
"
layers_per_block
"
]
+
1
):
replace_indices
=
{
"
old
"
:
f
"
up_blocks.
{
i
}
"
,
"
new
"
:
f
"
up_blocks.
{
block_id
}
"
}
paths
=
renew_resnet_paths
(
blocks
[
j
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
replace_indices
])
if
any
(
'attn'
in
layer
for
layer
in
up_blocks
[
i
]):
num_attn
=
len
({
'.'
.
join
(
shave_segments
(
layer
,
2
).
split
(
'.'
)[:
2
])
for
layer
in
up_blocks
[
i
]
if
'attn'
in
layer
})
attns
=
{
layer_id
:
[
key
for
key
in
up_blocks
[
i
]
if
f
'attn.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_blocks
)}
if
any
(
"attn"
in
layer
for
layer
in
up_blocks
[
i
]):
num_attn
=
len
(
{
"."
.
join
(
shave_segments
(
layer
,
2
).
split
(
"."
)[:
2
])
for
layer
in
up_blocks
[
i
]
if
"attn"
in
layer
}
)
attns
=
{
layer_id
:
[
key
for
key
in
up_blocks
[
i
]
if
f
"attn.
{
layer_id
}
"
in
key
]
for
layer_id
in
range
(
num_blocks
)
}
if
num_attn
>
0
:
for
j
in
range
(
config
[
'
layers_per_block
'
]
+
1
):
replace_indices
=
{
'
old
'
:
f
'
up_blocks.
{
i
}
'
,
'
new
'
:
f
'
up_blocks.
{
block_id
}
'
}
for
j
in
range
(
config
[
"
layers_per_block
"
]
+
1
):
replace_indices
=
{
"
old
"
:
f
"
up_blocks.
{
i
}
"
,
"
new
"
:
f
"
up_blocks.
{
block_id
}
"
}
paths
=
renew_attention_paths
(
attns
[
j
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
replace_indices
])
new_checkpoint
=
{
k
.
replace
(
'
mid_new_2
'
,
'
mid_block
'
):
v
for
k
,
v
in
new_checkpoint
.
items
()}
new_checkpoint
=
{
k
.
replace
(
"
mid_new_2
"
,
"
mid_block
"
):
v
for
k
,
v
in
new_checkpoint
.
items
()}
return
new_checkpoint
...
...
@@ -201,50 +238,66 @@ def convert_vq_autoenc_checkpoint(checkpoint, config):
"""
new_checkpoint
=
{}
new_checkpoint
[
'
encoder.conv_norm_out.weight
'
]
=
checkpoint
[
'
encoder.norm_out.weight
'
]
new_checkpoint
[
'
encoder.conv_norm_out.bias
'
]
=
checkpoint
[
'
encoder.norm_out.bias
'
]
new_checkpoint
[
"
encoder.conv_norm_out.weight
"
]
=
checkpoint
[
"
encoder.norm_out.weight
"
]
new_checkpoint
[
"
encoder.conv_norm_out.bias
"
]
=
checkpoint
[
"
encoder.norm_out.bias
"
]
new_checkpoint
[
'
encoder.conv_in.weight
'
]
=
checkpoint
[
'
encoder.conv_in.weight
'
]
new_checkpoint
[
'
encoder.conv_in.bias
'
]
=
checkpoint
[
'
encoder.conv_in.bias
'
]
new_checkpoint
[
'
encoder.conv_out.weight
'
]
=
checkpoint
[
'
encoder.conv_out.weight
'
]
new_checkpoint
[
'
encoder.conv_out.bias
'
]
=
checkpoint
[
'
encoder.conv_out.bias
'
]
new_checkpoint
[
"
encoder.conv_in.weight
"
]
=
checkpoint
[
"
encoder.conv_in.weight
"
]
new_checkpoint
[
"
encoder.conv_in.bias
"
]
=
checkpoint
[
"
encoder.conv_in.bias
"
]
new_checkpoint
[
"
encoder.conv_out.weight
"
]
=
checkpoint
[
"
encoder.conv_out.weight
"
]
new_checkpoint
[
"
encoder.conv_out.bias
"
]
=
checkpoint
[
"
encoder.conv_out.bias
"
]
new_checkpoint
[
'
decoder.conv_norm_out.weight
'
]
=
checkpoint
[
'
decoder.norm_out.weight
'
]
new_checkpoint
[
'
decoder.conv_norm_out.bias
'
]
=
checkpoint
[
'
decoder.norm_out.bias
'
]
new_checkpoint
[
"
decoder.conv_norm_out.weight
"
]
=
checkpoint
[
"
decoder.norm_out.weight
"
]
new_checkpoint
[
"
decoder.conv_norm_out.bias
"
]
=
checkpoint
[
"
decoder.norm_out.bias
"
]
new_checkpoint
[
'
decoder.conv_in.weight
'
]
=
checkpoint
[
'
decoder.conv_in.weight
'
]
new_checkpoint
[
'
decoder.conv_in.bias
'
]
=
checkpoint
[
'
decoder.conv_in.bias
'
]
new_checkpoint
[
'
decoder.conv_out.weight
'
]
=
checkpoint
[
'
decoder.conv_out.weight
'
]
new_checkpoint
[
'
decoder.conv_out.bias
'
]
=
checkpoint
[
'
decoder.conv_out.bias
'
]
new_checkpoint
[
"
decoder.conv_in.weight
"
]
=
checkpoint
[
"
decoder.conv_in.weight
"
]
new_checkpoint
[
"
decoder.conv_in.bias
"
]
=
checkpoint
[
"
decoder.conv_in.bias
"
]
new_checkpoint
[
"
decoder.conv_out.weight
"
]
=
checkpoint
[
"
decoder.conv_out.weight
"
]
new_checkpoint
[
"
decoder.conv_out.bias
"
]
=
checkpoint
[
"
decoder.conv_out.bias
"
]
num_down_blocks
=
len
({
'.'
.
join
(
layer
.
split
(
'.'
)[:
3
])
for
layer
in
checkpoint
if
'down'
in
layer
})
down_blocks
=
{
layer_id
:
[
key
for
key
in
checkpoint
if
f
'down.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_down_blocks
)}
num_down_blocks
=
len
({
"."
.
join
(
layer
.
split
(
"."
)[:
3
])
for
layer
in
checkpoint
if
"down"
in
layer
})
down_blocks
=
{
layer_id
:
[
key
for
key
in
checkpoint
if
f
"down.
{
layer_id
}
"
in
key
]
for
layer_id
in
range
(
num_down_blocks
)
}
num_up_blocks
=
len
({
'.'
.
join
(
layer
.
split
(
'.'
)[:
3
])
for
layer
in
checkpoint
if
'
up
'
in
layer
})
up_blocks
=
{
layer_id
:
[
key
for
key
in
checkpoint
if
f
'
up.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_up_blocks
)}
num_up_blocks
=
len
({
"."
.
join
(
layer
.
split
(
"."
)[:
3
])
for
layer
in
checkpoint
if
"
up
"
in
layer
})
up_blocks
=
{
layer_id
:
[
key
for
key
in
checkpoint
if
f
"
up.
{
layer_id
}
"
in
key
]
for
layer_id
in
range
(
num_up_blocks
)}
for
i
in
range
(
num_down_blocks
):
block_id
=
(
i
-
1
)
//
(
config
[
'layers_per_block'
]
+
1
)
if
any
(
'downsample'
in
layer
for
layer
in
down_blocks
[
i
]):
new_checkpoint
[
f
'encoder.down_blocks.
{
i
}
.downsamplers.0.conv.weight'
]
=
checkpoint
[
f
'encoder.down.
{
i
}
.downsample.conv.weight'
]
new_checkpoint
[
f
'encoder.down_blocks.
{
i
}
.downsamplers.0.conv.bias'
]
=
checkpoint
[
f
'encoder.down.
{
i
}
.downsample.conv.bias'
]
if
any
(
'block'
in
layer
for
layer
in
down_blocks
[
i
]):
num_blocks
=
len
({
'.'
.
join
(
shave_segments
(
layer
,
3
).
split
(
'.'
)[:
3
])
for
layer
in
down_blocks
[
i
]
if
'block'
in
layer
})
blocks
=
{
layer_id
:
[
key
for
key
in
down_blocks
[
i
]
if
f
'block.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_blocks
)}
block_id
=
(
i
-
1
)
//
(
config
[
"layers_per_block"
]
+
1
)
if
any
(
"downsample"
in
layer
for
layer
in
down_blocks
[
i
]):
new_checkpoint
[
f
"encoder.down_blocks.
{
i
}
.downsamplers.0.conv.weight"
]
=
checkpoint
[
f
"encoder.down.
{
i
}
.downsample.conv.weight"
]
new_checkpoint
[
f
"encoder.down_blocks.
{
i
}
.downsamplers.0.conv.bias"
]
=
checkpoint
[
f
"encoder.down.
{
i
}
.downsample.conv.bias"
]
if
any
(
"block"
in
layer
for
layer
in
down_blocks
[
i
]):
num_blocks
=
len
(
{
"."
.
join
(
shave_segments
(
layer
,
3
).
split
(
"."
)[:
3
])
for
layer
in
down_blocks
[
i
]
if
"block"
in
layer
}
)
blocks
=
{
layer_id
:
[
key
for
key
in
down_blocks
[
i
]
if
f
"block.
{
layer_id
}
"
in
key
]
for
layer_id
in
range
(
num_blocks
)
}
if
num_blocks
>
0
:
for
j
in
range
(
config
[
'
layers_per_block
'
]):
for
j
in
range
(
config
[
"
layers_per_block
"
]):
paths
=
renew_resnet_paths
(
blocks
[
j
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
)
if
any
(
'attn'
in
layer
for
layer
in
down_blocks
[
i
]):
num_attn
=
len
({
'.'
.
join
(
shave_segments
(
layer
,
3
).
split
(
'.'
)[:
3
])
for
layer
in
down_blocks
[
i
]
if
'attn'
in
layer
})
attns
=
{
layer_id
:
[
key
for
key
in
down_blocks
[
i
]
if
f
'attn.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_blocks
)}
if
any
(
"attn"
in
layer
for
layer
in
down_blocks
[
i
]):
num_attn
=
len
(
{
"."
.
join
(
shave_segments
(
layer
,
3
).
split
(
"."
)[:
3
])
for
layer
in
down_blocks
[
i
]
if
"attn"
in
layer
}
)
attns
=
{
layer_id
:
[
key
for
key
in
down_blocks
[
i
]
if
f
"attn.
{
layer_id
}
"
in
key
]
for
layer_id
in
range
(
num_blocks
)
}
if
num_attn
>
0
:
for
j
in
range
(
config
[
'
layers_per_block
'
]):
for
j
in
range
(
config
[
"
layers_per_block
"
]):
paths
=
renew_attention_paths
(
attns
[
j
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
config
=
config
)
...
...
@@ -254,48 +307,69 @@ def convert_vq_autoenc_checkpoint(checkpoint, config):
# Mid new 2
paths
=
renew_resnet_paths
(
mid_block_1_layers
)
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
{
'old'
:
'mid.'
,
'new'
:
'mid_new_2.'
},
{
'old'
:
'block_1'
,
'new'
:
'resnets.0'
}
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[{
"old"
:
"mid."
,
"new"
:
"mid_new_2."
},
{
"old"
:
"block_1"
,
"new"
:
"resnets.0"
}],
)
paths
=
renew_resnet_paths
(
mid_block_2_layers
)
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
{
'old'
:
'mid.'
,
'new'
:
'mid_new_2.'
},
{
'old'
:
'block_2'
,
'new'
:
'resnets.1'
}
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[{
"old"
:
"mid."
,
"new"
:
"mid_new_2."
},
{
"old"
:
"block_2"
,
"new"
:
"resnets.1"
}],
)
paths
=
renew_attention_paths
(
mid_attn_1_layers
,
in_mid
=
True
)
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
{
'old'
:
'mid.'
,
'new'
:
'mid_new_2.'
},
{
'old'
:
'attn_1'
,
'new'
:
'attentions.0'
}
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[{
"old"
:
"mid."
,
"new"
:
"mid_new_2."
},
{
"old"
:
"attn_1"
,
"new"
:
"attentions.0"
}],
)
for
i
in
range
(
num_up_blocks
):
block_id
=
num_up_blocks
-
1
-
i
if
any
(
'upsample'
in
layer
for
layer
in
up_blocks
[
i
]):
new_checkpoint
[
f
'decoder.up_blocks.
{
block_id
}
.upsamplers.0.conv.weight'
]
=
checkpoint
[
f
'decoder.up.
{
i
}
.upsample.conv.weight'
]
new_checkpoint
[
f
'decoder.up_blocks.
{
block_id
}
.upsamplers.0.conv.bias'
]
=
checkpoint
[
f
'decoder.up.
{
i
}
.upsample.conv.bias'
]
if
any
(
'block'
in
layer
for
layer
in
up_blocks
[
i
]):
num_blocks
=
len
({
'.'
.
join
(
shave_segments
(
layer
,
3
).
split
(
'.'
)[:
3
])
for
layer
in
up_blocks
[
i
]
if
'block'
in
layer
})
blocks
=
{
layer_id
:
[
key
for
key
in
up_blocks
[
i
]
if
f
'block.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_blocks
)}
if
any
(
"upsample"
in
layer
for
layer
in
up_blocks
[
i
]):
new_checkpoint
[
f
"decoder.up_blocks.
{
block_id
}
.upsamplers.0.conv.weight"
]
=
checkpoint
[
f
"decoder.up.
{
i
}
.upsample.conv.weight"
]
new_checkpoint
[
f
"decoder.up_blocks.
{
block_id
}
.upsamplers.0.conv.bias"
]
=
checkpoint
[
f
"decoder.up.
{
i
}
.upsample.conv.bias"
]
if
any
(
"block"
in
layer
for
layer
in
up_blocks
[
i
]):
num_blocks
=
len
(
{
"."
.
join
(
shave_segments
(
layer
,
3
).
split
(
"."
)[:
3
])
for
layer
in
up_blocks
[
i
]
if
"block"
in
layer
}
)
blocks
=
{
layer_id
:
[
key
for
key
in
up_blocks
[
i
]
if
f
"block.
{
layer_id
}
"
in
key
]
for
layer_id
in
range
(
num_blocks
)
}
if
num_blocks
>
0
:
for
j
in
range
(
config
[
'
layers_per_block
'
]
+
1
):
replace_indices
=
{
'
old
'
:
f
'
up_blocks.
{
i
}
'
,
'
new
'
:
f
'
up_blocks.
{
block_id
}
'
}
for
j
in
range
(
config
[
"
layers_per_block
"
]
+
1
):
replace_indices
=
{
"
old
"
:
f
"
up_blocks.
{
i
}
"
,
"
new
"
:
f
"
up_blocks.
{
block_id
}
"
}
paths
=
renew_resnet_paths
(
blocks
[
j
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
replace_indices
])
if
any
(
'attn'
in
layer
for
layer
in
up_blocks
[
i
]):
num_attn
=
len
({
'.'
.
join
(
shave_segments
(
layer
,
3
).
split
(
'.'
)[:
3
])
for
layer
in
up_blocks
[
i
]
if
'attn'
in
layer
})
attns
=
{
layer_id
:
[
key
for
key
in
up_blocks
[
i
]
if
f
'attn.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_blocks
)}
if
any
(
"attn"
in
layer
for
layer
in
up_blocks
[
i
]):
num_attn
=
len
(
{
"."
.
join
(
shave_segments
(
layer
,
3
).
split
(
"."
)[:
3
])
for
layer
in
up_blocks
[
i
]
if
"attn"
in
layer
}
)
attns
=
{
layer_id
:
[
key
for
key
in
up_blocks
[
i
]
if
f
"attn.
{
layer_id
}
"
in
key
]
for
layer_id
in
range
(
num_blocks
)
}
if
num_attn
>
0
:
for
j
in
range
(
config
[
'
layers_per_block
'
]
+
1
):
replace_indices
=
{
'
old
'
:
f
'
up_blocks.
{
i
}
'
,
'
new
'
:
f
'
up_blocks.
{
block_id
}
'
}
for
j
in
range
(
config
[
"
layers_per_block
"
]
+
1
):
replace_indices
=
{
"
old
"
:
f
"
up_blocks.
{
i
}
"
,
"
new
"
:
f
"
up_blocks.
{
block_id
}
"
}
paths
=
renew_attention_paths
(
attns
[
j
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
replace_indices
])
new_checkpoint
=
{
k
.
replace
(
'
mid_new_2
'
,
'
mid_block
'
):
v
for
k
,
v
in
new_checkpoint
.
items
()}
new_checkpoint
=
{
k
.
replace
(
"
mid_new_2
"
,
"
mid_block
"
):
v
for
k
,
v
in
new_checkpoint
.
items
()}
new_checkpoint
[
"quant_conv.weight"
]
=
checkpoint
[
"quant_conv.weight"
]
new_checkpoint
[
"quant_conv.bias"
]
=
checkpoint
[
"quant_conv.bias"
]
if
"quantize.embedding.weight"
in
checkpoint
:
...
...
@@ -321,9 +395,7 @@ if __name__ == "__main__":
help
=
"The config json file corresponding to the architecture."
,
)
parser
.
add_argument
(
"--dump_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the output model."
)
parser
.
add_argument
(
"--dump_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the output model."
)
args
=
parser
.
parse_args
()
checkpoint
=
torch
.
load
(
args
.
checkpoint_path
)
...
...
scripts/convert_ldm_original_checkpoint_to_diffusers.py
View file @
89793a97
...
...
@@ -16,8 +16,10 @@
import
argparse
import
json
import
torch
from
diffusers
import
VQModel
,
DDPMScheduler
,
UNet2DModel
,
LDMPipeline
from
diffusers
import
DDPMScheduler
,
LDMPipeline
,
UNet2DModel
,
VQModel
def
shave_segments
(
path
,
n_shave_prefix_segments
=
1
):
...
...
@@ -25,9 +27,9 @@ def shave_segments(path, n_shave_prefix_segments=1):
Removes segments. Positive values shave the first segments, negative shave the last segments.
"""
if
n_shave_prefix_segments
>=
0
:
return
'.'
.
join
(
path
.
split
(
'.'
)[
n_shave_prefix_segments
:])
return
"."
.
join
(
path
.
split
(
"."
)[
n_shave_prefix_segments
:])
else
:
return
'.'
.
join
(
path
.
split
(
'.'
)[:
n_shave_prefix_segments
])
return
"."
.
join
(
path
.
split
(
"."
)[:
n_shave_prefix_segments
])
def
renew_resnet_paths
(
old_list
,
n_shave_prefix_segments
=
0
):
...
...
@@ -36,18 +38,18 @@ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
"""
mapping
=
[]
for
old_item
in
old_list
:
new_item
=
old_item
.
replace
(
'
in_layers.0
'
,
'
norm1
'
)
new_item
=
new_item
.
replace
(
'
in_layers.2
'
,
'
conv1
'
)
new_item
=
old_item
.
replace
(
"
in_layers.0
"
,
"
norm1
"
)
new_item
=
new_item
.
replace
(
"
in_layers.2
"
,
"
conv1
"
)
new_item
=
new_item
.
replace
(
'
out_layers.0
'
,
'
norm2
'
)
new_item
=
new_item
.
replace
(
'
out_layers.3
'
,
'
conv2
'
)
new_item
=
new_item
.
replace
(
"
out_layers.0
"
,
"
norm2
"
)
new_item
=
new_item
.
replace
(
"
out_layers.3
"
,
"
conv2
"
)
new_item
=
new_item
.
replace
(
'
emb_layers.1
'
,
'
time_emb_proj
'
)
new_item
=
new_item
.
replace
(
'
skip_connection
'
,
'
conv_shortcut
'
)
new_item
=
new_item
.
replace
(
"
emb_layers.1
"
,
"
time_emb_proj
"
)
new_item
=
new_item
.
replace
(
"
skip_connection
"
,
"
conv_shortcut
"
)
new_item
=
shave_segments
(
new_item
,
n_shave_prefix_segments
=
n_shave_prefix_segments
)
mapping
.
append
({
'
old
'
:
old_item
,
'
new
'
:
new_item
})
mapping
.
append
({
"
old
"
:
old_item
,
"
new
"
:
new_item
})
return
mapping
...
...
@@ -60,20 +62,22 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
for
old_item
in
old_list
:
new_item
=
old_item
new_item
=
new_item
.
replace
(
'
norm.weight
'
,
'
group_norm.weight
'
)
new_item
=
new_item
.
replace
(
'
norm.bias
'
,
'
group_norm.bias
'
)
new_item
=
new_item
.
replace
(
"
norm.weight
"
,
"
group_norm.weight
"
)
new_item
=
new_item
.
replace
(
"
norm.bias
"
,
"
group_norm.bias
"
)
new_item
=
new_item
.
replace
(
'
proj_out.weight
'
,
'
proj_attn.weight
'
)
new_item
=
new_item
.
replace
(
'
proj_out.bias
'
,
'
proj_attn.bias
'
)
new_item
=
new_item
.
replace
(
"
proj_out.weight
"
,
"
proj_attn.weight
"
)
new_item
=
new_item
.
replace
(
"
proj_out.bias
"
,
"
proj_attn.bias
"
)
new_item
=
shave_segments
(
new_item
,
n_shave_prefix_segments
=
n_shave_prefix_segments
)
mapping
.
append
({
'
old
'
:
old_item
,
'
new
'
:
new_item
})
mapping
.
append
({
"
old
"
:
old_item
,
"
new
"
:
new_item
})
return
mapping
def
assign_to_checkpoint
(
paths
,
checkpoint
,
old_checkpoint
,
attention_paths_to_split
=
None
,
additional_replacements
=
None
,
config
=
None
):
def
assign_to_checkpoint
(
paths
,
checkpoint
,
old_checkpoint
,
attention_paths_to_split
=
None
,
additional_replacements
=
None
,
config
=
None
):
"""
This does the final conversion step: take locally converted weights and apply a global renaming
to them. It splits attention layers, and takes into account additional replacements
...
...
@@ -96,31 +100,31 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
old_tensor
=
old_tensor
.
reshape
((
num_heads
,
3
*
channels
//
num_heads
)
+
old_tensor
.
shape
[
1
:])
query
,
key
,
value
=
old_tensor
.
split
(
channels
//
num_heads
,
dim
=
1
)
checkpoint
[
path_map
[
'
query
'
]]
=
query
.
reshape
(
target_shape
)
checkpoint
[
path_map
[
'
key
'
]]
=
key
.
reshape
(
target_shape
)
checkpoint
[
path_map
[
'
value
'
]]
=
value
.
reshape
(
target_shape
)
checkpoint
[
path_map
[
"
query
"
]]
=
query
.
reshape
(
target_shape
)
checkpoint
[
path_map
[
"
key
"
]]
=
key
.
reshape
(
target_shape
)
checkpoint
[
path_map
[
"
value
"
]]
=
value
.
reshape
(
target_shape
)
for
path
in
paths
:
new_path
=
path
[
'
new
'
]
new_path
=
path
[
"
new
"
]
# These have already been assigned
if
attention_paths_to_split
is
not
None
and
new_path
in
attention_paths_to_split
:
continue
# Global renaming happens here
new_path
=
new_path
.
replace
(
'
middle_block.0
'
,
'
mid.resnets.0
'
)
new_path
=
new_path
.
replace
(
'
middle_block.1
'
,
'
mid.attentions.0
'
)
new_path
=
new_path
.
replace
(
'
middle_block.2
'
,
'
mid.resnets.1
'
)
new_path
=
new_path
.
replace
(
"
middle_block.0
"
,
"
mid.resnets.0
"
)
new_path
=
new_path
.
replace
(
"
middle_block.1
"
,
"
mid.attentions.0
"
)
new_path
=
new_path
.
replace
(
"
middle_block.2
"
,
"
mid.resnets.1
"
)
if
additional_replacements
is
not
None
:
for
replacement
in
additional_replacements
:
new_path
=
new_path
.
replace
(
replacement
[
'
old
'
],
replacement
[
'
new
'
])
new_path
=
new_path
.
replace
(
replacement
[
"
old
"
],
replacement
[
"
new
"
])
# proj_attn.weight has to be converted from conv 1D to linear
if
"proj_attn.weight"
in
new_path
:
checkpoint
[
new_path
]
=
old_checkpoint
[
path
[
'
old
'
]][:,
:,
0
]
checkpoint
[
new_path
]
=
old_checkpoint
[
path
[
"
old
"
]][:,
:,
0
]
else
:
checkpoint
[
new_path
]
=
old_checkpoint
[
path
[
'
old
'
]]
checkpoint
[
new_path
]
=
old_checkpoint
[
path
[
"
old
"
]]
def
convert_ldm_checkpoint
(
checkpoint
,
config
):
...
...
@@ -129,60 +133,78 @@ def convert_ldm_checkpoint(checkpoint, config):
"""
new_checkpoint
=
{}
new_checkpoint
[
'
time_embedding.linear_1.weight
'
]
=
checkpoint
[
'
time_embed.0.weight
'
]
new_checkpoint
[
'
time_embedding.linear_1.bias
'
]
=
checkpoint
[
'
time_embed.0.bias
'
]
new_checkpoint
[
'
time_embedding.linear_2.weight
'
]
=
checkpoint
[
'
time_embed.2.weight
'
]
new_checkpoint
[
'
time_embedding.linear_2.bias
'
]
=
checkpoint
[
'
time_embed.2.bias
'
]
new_checkpoint
[
"
time_embedding.linear_1.weight
"
]
=
checkpoint
[
"
time_embed.0.weight
"
]
new_checkpoint
[
"
time_embedding.linear_1.bias
"
]
=
checkpoint
[
"
time_embed.0.bias
"
]
new_checkpoint
[
"
time_embedding.linear_2.weight
"
]
=
checkpoint
[
"
time_embed.2.weight
"
]
new_checkpoint
[
"
time_embedding.linear_2.bias
"
]
=
checkpoint
[
"
time_embed.2.bias
"
]
new_checkpoint
[
'
conv_in.weight
'
]
=
checkpoint
[
'
input_blocks.0.0.weight
'
]
new_checkpoint
[
'
conv_in.bias
'
]
=
checkpoint
[
'
input_blocks.0.0.bias
'
]
new_checkpoint
[
"
conv_in.weight
"
]
=
checkpoint
[
"
input_blocks.0.0.weight
"
]
new_checkpoint
[
"
conv_in.bias
"
]
=
checkpoint
[
"
input_blocks.0.0.bias
"
]
new_checkpoint
[
'
conv_norm_out.weight
'
]
=
checkpoint
[
'
out.0.weight
'
]
new_checkpoint
[
'
conv_norm_out.bias
'
]
=
checkpoint
[
'
out.0.bias
'
]
new_checkpoint
[
'
conv_out.weight
'
]
=
checkpoint
[
'
out.2.weight
'
]
new_checkpoint
[
'
conv_out.bias
'
]
=
checkpoint
[
'
out.2.bias
'
]
new_checkpoint
[
"
conv_norm_out.weight
"
]
=
checkpoint
[
"
out.0.weight
"
]
new_checkpoint
[
"
conv_norm_out.bias
"
]
=
checkpoint
[
"
out.0.bias
"
]
new_checkpoint
[
"
conv_out.weight
"
]
=
checkpoint
[
"
out.2.weight
"
]
new_checkpoint
[
"
conv_out.bias
"
]
=
checkpoint
[
"
out.2.bias
"
]
# Retrieves the keys for the input blocks only
num_input_blocks
=
len
({
'.'
.
join
(
layer
.
split
(
'.'
)[:
2
])
for
layer
in
checkpoint
if
'input_blocks'
in
layer
})
input_blocks
=
{
layer_id
:
[
key
for
key
in
checkpoint
if
f
'input_blocks.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_input_blocks
)}
num_input_blocks
=
len
({
"."
.
join
(
layer
.
split
(
"."
)[:
2
])
for
layer
in
checkpoint
if
"input_blocks"
in
layer
})
input_blocks
=
{
layer_id
:
[
key
for
key
in
checkpoint
if
f
"input_blocks.
{
layer_id
}
"
in
key
]
for
layer_id
in
range
(
num_input_blocks
)
}
# Retrieves the keys for the middle blocks only
num_middle_blocks
=
len
({
'.'
.
join
(
layer
.
split
(
'.'
)[:
2
])
for
layer
in
checkpoint
if
'middle_block'
in
layer
})
middle_blocks
=
{
layer_id
:
[
key
for
key
in
checkpoint
if
f
'middle_block.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_middle_blocks
)}
num_middle_blocks
=
len
({
"."
.
join
(
layer
.
split
(
"."
)[:
2
])
for
layer
in
checkpoint
if
"middle_block"
in
layer
})
middle_blocks
=
{
layer_id
:
[
key
for
key
in
checkpoint
if
f
"middle_block.
{
layer_id
}
"
in
key
]
for
layer_id
in
range
(
num_middle_blocks
)
}
# Retrieves the keys for the output blocks only
num_output_blocks
=
len
({
'.'
.
join
(
layer
.
split
(
'.'
)[:
2
])
for
layer
in
checkpoint
if
'output_blocks'
in
layer
})
output_blocks
=
{
layer_id
:
[
key
for
key
in
checkpoint
if
f
'output_blocks.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_output_blocks
)}
num_output_blocks
=
len
({
"."
.
join
(
layer
.
split
(
"."
)[:
2
])
for
layer
in
checkpoint
if
"output_blocks"
in
layer
})
output_blocks
=
{
layer_id
:
[
key
for
key
in
checkpoint
if
f
"output_blocks.
{
layer_id
}
"
in
key
]
for
layer_id
in
range
(
num_output_blocks
)
}
for
i
in
range
(
1
,
num_input_blocks
):
block_id
=
(
i
-
1
)
//
(
config
[
'
num_res_blocks
'
]
+
1
)
layer_in_block_id
=
(
i
-
1
)
%
(
config
[
'
num_res_blocks
'
]
+
1
)
block_id
=
(
i
-
1
)
//
(
config
[
"
num_res_blocks
"
]
+
1
)
layer_in_block_id
=
(
i
-
1
)
%
(
config
[
"
num_res_blocks
"
]
+
1
)
resnets
=
[
key
for
key
in
input_blocks
[
i
]
if
f
'
input_blocks.
{
i
}
.0
'
in
key
]
attentions
=
[
key
for
key
in
input_blocks
[
i
]
if
f
'
input_blocks.
{
i
}
.1
'
in
key
]
resnets
=
[
key
for
key
in
input_blocks
[
i
]
if
f
"
input_blocks.
{
i
}
.0
"
in
key
]
attentions
=
[
key
for
key
in
input_blocks
[
i
]
if
f
"
input_blocks.
{
i
}
.1
"
in
key
]
if
f
'input_blocks.
{
i
}
.0.op.weight'
in
checkpoint
:
new_checkpoint
[
f
'downsample_blocks.
{
block_id
}
.downsamplers.0.conv.weight'
]
=
checkpoint
[
f
'input_blocks.
{
i
}
.0.op.weight'
]
new_checkpoint
[
f
'downsample_blocks.
{
block_id
}
.downsamplers.0.conv.bias'
]
=
checkpoint
[
f
'input_blocks.
{
i
}
.0.op.bias'
]
if
f
"input_blocks.
{
i
}
.0.op.weight"
in
checkpoint
:
new_checkpoint
[
f
"downsample_blocks.
{
block_id
}
.downsamplers.0.conv.weight"
]
=
checkpoint
[
f
"input_blocks.
{
i
}
.0.op.weight"
]
new_checkpoint
[
f
"downsample_blocks.
{
block_id
}
.downsamplers.0.conv.bias"
]
=
checkpoint
[
f
"input_blocks.
{
i
}
.0.op.bias"
]
paths
=
renew_resnet_paths
(
resnets
)
meta_path
=
{
'old'
:
f
'input_blocks.
{
i
}
.0'
,
'new'
:
f
'downsample_blocks.
{
block_id
}
.resnets.
{
layer_in_block_id
}
'
}
resnet_op
=
{
'old'
:
'resnets.2.op'
,
'new'
:
'downsamplers.0.op'
}
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
meta_path
,
resnet_op
],
config
=
config
)
meta_path
=
{
"old"
:
f
"input_blocks.
{
i
}
.0"
,
"new"
:
f
"downsample_blocks.
{
block_id
}
.resnets.
{
layer_in_block_id
}
"
}
resnet_op
=
{
"old"
:
"resnets.2.op"
,
"new"
:
"downsamplers.0.op"
}
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
meta_path
,
resnet_op
],
config
=
config
)
if
len
(
attentions
):
paths
=
renew_attention_paths
(
attentions
)
meta_path
=
{
'old'
:
f
'input_blocks.
{
i
}
.1'
,
'new'
:
f
'downsample_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
'
}
meta_path
=
{
"old"
:
f
"input_blocks.
{
i
}
.1"
,
"new"
:
f
"downsample_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
"
,
}
to_split
=
{
f
'
input_blocks.
{
i
}
.1.qkv.bias
'
:
{
'
key
'
:
f
'
downsample_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.key.bias
'
,
'
query
'
:
f
'
downsample_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.query.bias
'
,
'
value
'
:
f
'
downsample_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.value.bias
'
,
f
"
input_blocks.
{
i
}
.1.qkv.bias
"
:
{
"
key
"
:
f
"
downsample_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.key.bias
"
,
"
query
"
:
f
"
downsample_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.query.bias
"
,
"
value
"
:
f
"
downsample_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.value.bias
"
,
},
f
'
input_blocks.
{
i
}
.1.qkv.weight
'
:
{
'
key
'
:
f
'
downsample_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.key.weight
'
,
'
query
'
:
f
'
downsample_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.query.weight
'
,
'
value
'
:
f
'
downsample_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.value.weight
'
,
f
"
input_blocks.
{
i
}
.1.qkv.weight
"
:
{
"
key
"
:
f
"
downsample_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.key.weight
"
,
"
query
"
:
f
"
downsample_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.query.weight
"
,
"
value
"
:
f
"
downsample_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.value.weight
"
,
},
}
assign_to_checkpoint
(
...
...
@@ -191,7 +213,7 @@ def convert_ldm_checkpoint(checkpoint, config):
checkpoint
,
additional_replacements
=
[
meta_path
],
attention_paths_to_split
=
to_split
,
config
=
config
config
=
config
,
)
resnet_0
=
middle_blocks
[
0
]
...
...
@@ -206,46 +228,52 @@ def convert_ldm_checkpoint(checkpoint, config):
attentions_paths
=
renew_attention_paths
(
attentions
)
to_split
=
{
'
middle_block.1.qkv.bias
'
:
{
'
key
'
:
'
mid_block.attentions.0.key.bias
'
,
'
query
'
:
'
mid_block.attentions.0.query.bias
'
,
'
value
'
:
'
mid_block.attentions.0.value.bias
'
,
"
middle_block.1.qkv.bias
"
:
{
"
key
"
:
"
mid_block.attentions.0.key.bias
"
,
"
query
"
:
"
mid_block.attentions.0.query.bias
"
,
"
value
"
:
"
mid_block.attentions.0.value.bias
"
,
},
'
middle_block.1.qkv.weight
'
:
{
'
key
'
:
'
mid_block.attentions.0.key.weight
'
,
'
query
'
:
'
mid_block.attentions.0.query.weight
'
,
'
value
'
:
'
mid_block.attentions.0.value.weight
'
,
"
middle_block.1.qkv.weight
"
:
{
"
key
"
:
"
mid_block.attentions.0.key.weight
"
,
"
query
"
:
"
mid_block.attentions.0.query.weight
"
,
"
value
"
:
"
mid_block.attentions.0.value.weight
"
,
},
}
assign_to_checkpoint
(
attentions_paths
,
new_checkpoint
,
checkpoint
,
attention_paths_to_split
=
to_split
,
config
=
config
)
assign_to_checkpoint
(
attentions_paths
,
new_checkpoint
,
checkpoint
,
attention_paths_to_split
=
to_split
,
config
=
config
)
for
i
in
range
(
num_output_blocks
):
block_id
=
i
//
(
config
[
'
num_res_blocks
'
]
+
1
)
layer_in_block_id
=
i
%
(
config
[
'
num_res_blocks
'
]
+
1
)
block_id
=
i
//
(
config
[
"
num_res_blocks
"
]
+
1
)
layer_in_block_id
=
i
%
(
config
[
"
num_res_blocks
"
]
+
1
)
output_block_layers
=
[
shave_segments
(
name
,
2
)
for
name
in
output_blocks
[
i
]]
output_block_list
=
{}
for
layer
in
output_block_layers
:
layer_id
,
layer_name
=
layer
.
split
(
'.'
)[
0
],
shave_segments
(
layer
,
1
)
layer_id
,
layer_name
=
layer
.
split
(
"."
)[
0
],
shave_segments
(
layer
,
1
)
if
layer_id
in
output_block_list
:
output_block_list
[
layer_id
].
append
(
layer_name
)
else
:
output_block_list
[
layer_id
]
=
[
layer_name
]
if
len
(
output_block_list
)
>
1
:
resnets
=
[
key
for
key
in
output_blocks
[
i
]
if
f
'
output_blocks.
{
i
}
.0
'
in
key
]
attentions
=
[
key
for
key
in
output_blocks
[
i
]
if
f
'
output_blocks.
{
i
}
.1
'
in
key
]
resnets
=
[
key
for
key
in
output_blocks
[
i
]
if
f
"
output_blocks.
{
i
}
.0
"
in
key
]
attentions
=
[
key
for
key
in
output_blocks
[
i
]
if
f
"
output_blocks.
{
i
}
.1
"
in
key
]
resnet_0_paths
=
renew_resnet_paths
(
resnets
)
paths
=
renew_resnet_paths
(
resnets
)
meta_path
=
{
'
old
'
:
f
'
output_blocks.
{
i
}
.0
'
,
'
new
'
:
f
'
up_blocks.
{
block_id
}
.resnets.
{
layer_in_block_id
}
'
}
meta_path
=
{
"
old
"
:
f
"
output_blocks.
{
i
}
.0
"
,
"
new
"
:
f
"
up_blocks.
{
block_id
}
.resnets.
{
layer_in_block_id
}
"
}
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
meta_path
],
config
=
config
)
if
[
'conv.weight'
,
'conv.bias'
]
in
output_block_list
.
values
():
index
=
list
(
output_block_list
.
values
()).
index
([
'conv.weight'
,
'conv.bias'
])
new_checkpoint
[
f
'up_blocks.
{
block_id
}
.upsamplers.0.conv.weight'
]
=
checkpoint
[
f
'output_blocks.
{
i
}
.
{
index
}
.conv.weight'
]
new_checkpoint
[
f
'up_blocks.
{
block_id
}
.upsamplers.0.conv.bias'
]
=
checkpoint
[
f
'output_blocks.
{
i
}
.
{
index
}
.conv.bias'
]
if
[
"conv.weight"
,
"conv.bias"
]
in
output_block_list
.
values
():
index
=
list
(
output_block_list
.
values
()).
index
([
"conv.weight"
,
"conv.bias"
])
new_checkpoint
[
f
"up_blocks.
{
block_id
}
.upsamplers.0.conv.weight"
]
=
checkpoint
[
f
"output_blocks.
{
i
}
.
{
index
}
.conv.weight"
]
new_checkpoint
[
f
"up_blocks.
{
block_id
}
.upsamplers.0.conv.bias"
]
=
checkpoint
[
f
"output_blocks.
{
i
}
.
{
index
}
.conv.bias"
]
# Clear attentions as they have been attributed above.
if
len
(
attentions
)
==
2
:
...
...
@@ -254,19 +282,19 @@ def convert_ldm_checkpoint(checkpoint, config):
if
len
(
attentions
):
paths
=
renew_attention_paths
(
attentions
)
meta_path
=
{
'
old
'
:
f
'
output_blocks.
{
i
}
.1
'
,
'
new
'
:
f
'
up_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
'
"
old
"
:
f
"
output_blocks.
{
i
}
.1
"
,
"
new
"
:
f
"
up_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
"
,
}
to_split
=
{
f
'
output_blocks.
{
i
}
.1.qkv.bias
'
:
{
'
key
'
:
f
'
up_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.key.bias
'
,
'
query
'
:
f
'
up_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.query.bias
'
,
'
value
'
:
f
'
up_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.value.bias
'
,
f
"
output_blocks.
{
i
}
.1.qkv.bias
"
:
{
"
key
"
:
f
"
up_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.key.bias
"
,
"
query
"
:
f
"
up_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.query.bias
"
,
"
value
"
:
f
"
up_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.value.bias
"
,
},
f
'
output_blocks.
{
i
}
.1.qkv.weight
'
:
{
'
key
'
:
f
'
up_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.key.weight
'
,
'
query
'
:
f
'
up_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.query.weight
'
,
'
value
'
:
f
'
up_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.value.weight
'
,
f
"
output_blocks.
{
i
}
.1.qkv.weight
"
:
{
"
key
"
:
f
"
up_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.key.weight
"
,
"
query
"
:
f
"
up_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.query.weight
"
,
"
value
"
:
f
"
up_blocks.
{
block_id
}
.attentions.
{
layer_in_block_id
}
.value.weight
"
,
},
}
assign_to_checkpoint
(
...
...
@@ -274,14 +302,14 @@ def convert_ldm_checkpoint(checkpoint, config):
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
meta_path
],
attention_paths_to_split
=
to_split
if
any
(
'
qkv
'
in
key
for
key
in
attentions
)
else
None
,
attention_paths_to_split
=
to_split
if
any
(
"
qkv
"
in
key
for
key
in
attentions
)
else
None
,
config
=
config
,
)
else
:
resnet_0_paths
=
renew_resnet_paths
(
output_block_layers
,
n_shave_prefix_segments
=
1
)
for
path
in
resnet_0_paths
:
old_path
=
'.'
.
join
([
'
output_blocks
'
,
str
(
i
),
path
[
'
old
'
]])
new_path
=
'.'
.
join
([
'
up_blocks
'
,
str
(
block_id
),
'
resnets
'
,
str
(
layer_in_block_id
),
path
[
'
new
'
]])
old_path
=
"."
.
join
([
"
output_blocks
"
,
str
(
i
),
path
[
"
old
"
]])
new_path
=
"."
.
join
([
"
up_blocks
"
,
str
(
block_id
),
"
resnets
"
,
str
(
layer_in_block_id
),
path
[
"
new
"
]])
new_checkpoint
[
new_path
]
=
checkpoint
[
old_path
]
...
...
@@ -303,9 +331,7 @@ if __name__ == "__main__":
help
=
"The config json file corresponding to the architecture."
,
)
parser
.
add_argument
(
"--dump_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the output model."
)
parser
.
add_argument
(
"--dump_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the output model."
)
args
=
parser
.
parse_args
()
...
...
scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py
View file @
89793a97
...
...
@@ -16,8 +16,10 @@
import
argparse
import
json
import
torch
from
diffusers
import
UNet2DModel
from
diffusers
import
ScoreSdeVePipeline
,
ScoreSdeVeScheduler
,
UNet2DModel
def
convert_ncsnpp_checkpoint
(
checkpoint
,
config
):
...
...
scripts/generate_logits.py
View file @
89793a97
from
huggingface_hub
import
HfApi
from
transformers.file_utils
import
has_file
from
diffusers
import
UNet2DModel
import
random
import
torch
from
diffusers
import
UNet2DModel
from
huggingface_hub
import
HfApi
api
=
HfApi
()
results
=
{}
results
[
"google_ddpm_cifar10_32"
]
=
torch
.
tensor
([
-
0.7515
,
-
1.6883
,
0.2420
,
0.0300
,
0.6347
,
1.3433
,
-
1.1743
,
-
3.7467
,
# fmt: off
results
[
"google_ddpm_cifar10_32"
]
=
torch
.
tensor
([
-
0.7515
,
-
1.6883
,
0.2420
,
0.0300
,
0.6347
,
1.3433
,
-
1.1743
,
-
3.7467
,
1.2342
,
-
2.2485
,
0.4636
,
0.8076
,
-
0.7991
,
0.3969
,
0.8498
,
0.9189
,
-
1.8887
,
-
3.3522
,
0.7639
,
0.2040
,
0.6271
,
-
2.7148
,
-
1.6316
,
3.0839
,
0.3186
,
0.2721
,
-
0.9759
,
-
1.2461
,
2.6257
,
1.3557
])
results
[
"google_ddpm_ema_bedroom_256"
]
=
torch
.
tensor
([
-
2.3639
,
-
2.5344
,
0.0054
,
-
0.6674
,
1.5990
,
1.0158
,
0.3124
,
-
2.1436
,
0.3186
,
0.2721
,
-
0.9759
,
-
1.2461
,
2.6257
,
1.3557
])
results
[
"google_ddpm_ema_bedroom_256"
]
=
torch
.
tensor
([
-
2.3639
,
-
2.5344
,
0.0054
,
-
0.6674
,
1.5990
,
1.0158
,
0.3124
,
-
2.1436
,
1.8795
,
-
2.5429
,
-
0.1566
,
-
0.3973
,
1.2490
,
2.6447
,
1.2283
,
-
0.5208
,
-
2.8154
,
-
3.5119
,
2.3838
,
1.2033
,
1.7201
,
-
2.1256
,
-
1.4576
,
2.7948
,
2.4204
,
-
0.9752
,
-
1.2546
,
0.8027
,
3.2758
,
3.1365
])
results
[
"CompVis_ldm_celebahq_256"
]
=
torch
.
tensor
([
-
0.6531
,
-
0.6891
,
-
0.3172
,
-
0.5375
,
-
0.9140
,
-
0.5367
,
-
0.1175
,
-
0.7869
,
2.4204
,
-
0.9752
,
-
1.2546
,
0.8027
,
3.2758
,
3.1365
])
results
[
"CompVis_ldm_celebahq_256"
]
=
torch
.
tensor
([
-
0.6531
,
-
0.6891
,
-
0.3172
,
-
0.5375
,
-
0.9140
,
-
0.5367
,
-
0.1175
,
-
0.7869
,
-
0.3808
,
-
0.4513
,
-
0.2098
,
-
0.0083
,
0.3183
,
0.5140
,
0.2247
,
-
0.1304
,
-
0.1302
,
-
0.2802
,
-
0.2084
,
-
0.2025
,
-
0.4967
,
-
0.4873
,
-
0.0861
,
0.6925
,
0.0250
,
0.1290
,
-
0.1543
,
0.6316
,
1.0460
,
1.4943
])
results
[
"google_ncsnpp_ffhq_1024"
]
=
torch
.
tensor
([
0.0911
,
0.1107
,
0.0182
,
0.0435
,
-
0.0805
,
-
0.0608
,
0.0381
,
0.2172
,
0.0250
,
0.1290
,
-
0.1543
,
0.6316
,
1.0460
,
1.4943
])
results
[
"google_ncsnpp_ffhq_1024"
]
=
torch
.
tensor
([
0.0911
,
0.1107
,
0.0182
,
0.0435
,
-
0.0805
,
-
0.0608
,
0.0381
,
0.2172
,
-
0.0280
,
0.1327
,
-
0.0299
,
-
0.0255
,
-
0.0050
,
-
0.1170
,
-
0.1046
,
0.0309
,
0.1367
,
0.1728
,
-
0.0533
,
-
0.0748
,
-
0.0534
,
0.1624
,
0.0384
,
-
0.1805
,
-
0.0707
,
0.0642
,
0.0220
,
-
0.0134
,
-
0.1333
,
-
0.1505
])
results
[
"google_ncsnpp_bedroom_256"
]
=
torch
.
tensor
([
0.1321
,
0.1337
,
0.0440
,
0.0622
,
-
0.0591
,
-
0.0370
,
0.0503
,
0.2133
,
-
0.0707
,
0.0642
,
0.0220
,
-
0.0134
,
-
0.1333
,
-
0.1505
])
results
[
"google_ncsnpp_bedroom_256"
]
=
torch
.
tensor
([
0.1321
,
0.1337
,
0.0440
,
0.0622
,
-
0.0591
,
-
0.0370
,
0.0503
,
0.2133
,
-
0.0177
,
0.1415
,
-
0.0116
,
-
0.0112
,
0.0044
,
-
0.0980
,
-
0.0789
,
0.0395
,
0.1502
,
0.1785
,
-
0.0488
,
-
0.0514
,
-
0.0404
,
0.1539
,
0.0454
,
-
0.1559
,
-
0.0665
,
0.0659
,
0.0383
,
-
0.0005
,
-
0.1266
,
-
0.1386
])
results
[
"google_ncsnpp_celebahq_256"
]
=
torch
.
tensor
([
0.1154
,
0.1218
,
0.0307
,
0.0526
,
-
0.0711
,
-
0.0541
,
0.0366
,
0.2078
,
-
0.0665
,
0.0659
,
0.0383
,
-
0.0005
,
-
0.1266
,
-
0.1386
])
results
[
"google_ncsnpp_celebahq_256"
]
=
torch
.
tensor
([
0.1154
,
0.1218
,
0.0307
,
0.0526
,
-
0.0711
,
-
0.0541
,
0.0366
,
0.2078
,
-
0.0267
,
0.1317
,
-
0.0226
,
-
0.0193
,
-
0.0014
,
-
0.1055
,
-
0.0902
,
0.0330
,
0.1391
,
0.1709
,
-
0.0562
,
-
0.0693
,
-
0.0560
,
0.1482
,
0.0381
,
-
0.1683
,
-
0.0681
,
0.0661
,
0.0331
,
-
0.0046
,
-
0.1268
,
-
0.1431
])
results
[
"google_ncsnpp_church_256"
]
=
torch
.
tensor
([
0.1192
,
0.1240
,
0.0414
,
0.0606
,
-
0.0557
,
-
0.0412
,
0.0430
,
0.2042
,
-
0.0681
,
0.0661
,
0.0331
,
-
0.0046
,
-
0.1268
,
-
0.1431
])
results
[
"google_ncsnpp_church_256"
]
=
torch
.
tensor
([
0.1192
,
0.1240
,
0.0414
,
0.0606
,
-
0.0557
,
-
0.0412
,
0.0430
,
0.2042
,
-
0.0200
,
0.1385
,
-
0.0115
,
-
0.0132
,
0.0017
,
-
0.0965
,
-
0.0802
,
0.0398
,
0.1433
,
0.1747
,
-
0.0458
,
-
0.0533
,
-
0.0407
,
0.1545
,
0.0419
,
-
0.1574
,
-
0.0645
,
0.0626
,
0.0341
,
-
0.0010
,
-
0.1199
,
-
0.1390
])
results
[
"google_ncsnpp_ffhq_256"
]
=
torch
.
tensor
([
0.1075
,
0.1074
,
0.0205
,
0.0431
,
-
0.0774
,
-
0.0607
,
0.0298
,
0.2042
,
-
0.0645
,
0.0626
,
0.0341
,
-
0.0010
,
-
0.1199
,
-
0.1390
])
results
[
"google_ncsnpp_ffhq_256"
]
=
torch
.
tensor
([
0.1075
,
0.1074
,
0.0205
,
0.0431
,
-
0.0774
,
-
0.0607
,
0.0298
,
0.2042
,
-
0.0320
,
0.1267
,
-
0.0281
,
-
0.0250
,
-
0.0064
,
-
0.1091
,
-
0.0946
,
0.0290
,
0.1328
,
0.1650
,
-
0.0580
,
-
0.0738
,
-
0.0586
,
0.1440
,
0.0337
,
-
0.1746
,
-
0.0712
,
0.0605
,
0.0250
,
-
0.0099
,
-
0.1316
,
-
0.1473
])
results
[
"google_ddpm_cat_256"
]
=
torch
.
tensor
([
-
1.4572
,
-
2.0481
,
-
0.0414
,
-
0.6005
,
1.4136
,
0.5848
,
0.4028
,
-
2.7330
,
-
0.0712
,
0.0605
,
0.0250
,
-
0.0099
,
-
0.1316
,
-
0.1473
])
results
[
"google_ddpm_cat_256"
]
=
torch
.
tensor
([
-
1.4572
,
-
2.0481
,
-
0.0414
,
-
0.6005
,
1.4136
,
0.5848
,
0.4028
,
-
2.7330
,
1.2212
,
-
2.1228
,
0.2155
,
0.4039
,
0.7662
,
2.0535
,
0.7477
,
-
0.3243
,
-
2.1758
,
-
2.7648
,
1.6947
,
0.7026
,
1.2338
,
-
1.6078
,
-
0.8682
,
2.2810
,
1.8574
,
-
0.5718
,
-
0.5586
,
-
0.0186
,
2.3415
,
2.1251
])
results
[
"google_ddpm_celebahq_256"
]
=
torch
.
tensor
([
-
1.3690
,
-
1.9720
,
-
0.4090
,
-
0.6966
,
1.4660
,
0.9938
,
-
0.1385
,
-
2.7324
,
results
[
"google_ddpm_celebahq_256"
]
=
torch
.
tensor
([
-
1.3690
,
-
1.9720
,
-
0.4090
,
-
0.6966
,
1.4660
,
0.9938
,
-
0.1385
,
-
2.7324
,
0.7736
,
-
1.8917
,
0.2923
,
0.4293
,
0.1693
,
1.4112
,
1.1887
,
-
0.3181
,
-
2.2160
,
-
2.6381
,
1.3170
,
0.8163
,
0.9240
,
-
1.6544
,
-
0.6099
,
2.5259
,
1.6430
,
-
0.9090
,
-
0.9392
,
-
0.0126
,
2.4268
,
2.3266
])
results
[
"google_ddpm_ema_celebahq_256"
]
=
torch
.
tensor
([
-
1.3525
,
-
1.9628
,
-
0.3956
,
-
0.6860
,
1.4664
,
1.0014
,
-
0.1259
,
-
2.7212
,
1.6430
,
-
0.9090
,
-
0.9392
,
-
0.0126
,
2.4268
,
2.3266
])
results
[
"google_ddpm_ema_celebahq_256"
]
=
torch
.
tensor
([
-
1.3525
,
-
1.9628
,
-
0.3956
,
-
0.6860
,
1.4664
,
1.0014
,
-
0.1259
,
-
2.7212
,
0.7772
,
-
1.8811
,
0.2996
,
0.4388
,
0.1704
,
1.4029
,
1.1701
,
-
0.3027
,
-
2.2053
,
-
2.6287
,
1.3350
,
0.8131
,
0.9274
,
-
1.6292
,
-
0.6098
,
2.5131
,
1.6505
,
-
0.8958
,
-
0.9298
,
-
0.0151
,
2.4257
,
2.3355
])
results
[
"google_ddpm_church_256"
]
=
torch
.
tensor
([
-
2.0585
,
-
2.7897
,
-
0.2850
,
-
0.8940
,
1.9052
,
0.5702
,
0.6345
,
-
3.8959
,
1.6505
,
-
0.8958
,
-
0.9298
,
-
0.0151
,
2.4257
,
2.3355
])
results
[
"google_ddpm_church_256"
]
=
torch
.
tensor
([
-
2.0585
,
-
2.7897
,
-
0.2850
,
-
0.8940
,
1.9052
,
0.5702
,
0.6345
,
-
3.8959
,
1.5932
,
-
3.2319
,
0.1974
,
0.0287
,
1.7566
,
2.6543
,
0.8387
,
-
0.5351
,
-
3.2736
,
-
4.3375
,
2.9029
,
1.6390
,
1.4640
,
-
2.1701
,
-
1.9013
,
2.9341
,
3.4981
,
-
0.6255
,
-
1.1644
,
-
0.1591
,
3.7097
,
3.2066
])
results
[
"google_ddpm_bedroom_256"
]
=
torch
.
tensor
([
-
2.3139
,
-
2.5594
,
-
0.0197
,
-
0.6785
,
1.7001
,
1.1606
,
0.3075
,
-
2.1740
,
3.4981
,
-
0.6255
,
-
1.1644
,
-
0.1591
,
3.7097
,
3.2066
])
results
[
"google_ddpm_bedroom_256"
]
=
torch
.
tensor
([
-
2.3139
,
-
2.5594
,
-
0.0197
,
-
0.6785
,
1.7001
,
1.1606
,
0.3075
,
-
2.1740
,
1.8071
,
-
2.5630
,
-
0.0926
,
-
0.3811
,
1.2116
,
2.6246
,
1.2731
,
-
0.5398
,
-
2.8153
,
-
3.6140
,
2.3893
,
1.3262
,
1.6258
,
-
2.1856
,
-
1.3267
,
2.8395
,
2.3779
,
-
1.0623
,
-
1.2468
,
0.8959
,
3.3367
,
3.2243
])
results
[
"google_ddpm_ema_church_256"
]
=
torch
.
tensor
([
-
2.0628
,
-
2.7667
,
-
0.2089
,
-
0.8263
,
2.0539
,
0.5992
,
0.6495
,
-
3.8336
,
2.3779
,
-
1.0623
,
-
1.2468
,
0.8959
,
3.3367
,
3.2243
])
results
[
"google_ddpm_ema_church_256"
]
=
torch
.
tensor
([
-
2.0628
,
-
2.7667
,
-
0.2089
,
-
0.8263
,
2.0539
,
0.5992
,
0.6495
,
-
3.8336
,
1.6025
,
-
3.2817
,
0.1721
,
-
0.0633
,
1.7516
,
2.7039
,
0.8100
,
-
0.5908
,
-
3.2113
,
-
4.4343
,
2.9257
,
1.3632
,
1.5562
,
-
2.1489
,
-
1.9894
,
3.0560
,
3.3396
,
-
0.7328
,
-
1.0417
,
0.0383
,
3.7093
,
3.2343
])
results
[
"google_ddpm_ema_cat_256"
]
=
torch
.
tensor
([
-
1.4574
,
-
2.0569
,
-
0.0473
,
-
0.6117
,
1.4018
,
0.5769
,
0.4129
,
-
2.7344
,
3.3396
,
-
0.7328
,
-
1.0417
,
0.0383
,
3.7093
,
3.2343
])
results
[
"google_ddpm_ema_cat_256"
]
=
torch
.
tensor
([
-
1.4574
,
-
2.0569
,
-
0.0473
,
-
0.6117
,
1.4018
,
0.5769
,
0.4129
,
-
2.7344
,
1.2241
,
-
2.1397
,
0.2000
,
0.3937
,
0.7616
,
2.0453
,
0.7324
,
-
0.3391
,
-
2.1746
,
-
2.7744
,
1.6963
,
0.6921
,
1.2187
,
-
1.6172
,
-
0.8877
,
2.2439
,
1.8471
,
-
0.5839
,
-
0.5605
,
-
0.0464
,
2.3250
,
2.1219
])
1.8471
,
-
0.5839
,
-
0.5605
,
-
0.0464
,
2.3250
,
2.1219
])
# fmt: on
models
=
api
.
list_models
(
filter
=
"diffusers"
)
for
mod
in
models
:
...
...
@@ -75,7 +109,7 @@ for mod in models:
print
(
f
"Started running
{
mod
.
modelId
}
!!!"
)
if
mod
.
modelId
.
startswith
(
"CompVis"
):
model
=
UNet2DModel
.
from_pretrained
(
local_checkpoint
,
subfolder
=
"unet"
)
model
=
UNet2DModel
.
from_pretrained
(
local_checkpoint
,
subfolder
=
"unet"
)
else
:
model
=
UNet2DModel
.
from_pretrained
(
local_checkpoint
)
...
...
@@ -85,7 +119,9 @@ for mod in models:
noise
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
sample_size
,
model
.
config
.
sample_size
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
])
with
torch
.
no_grad
():
logits
=
model
(
noise
,
time_step
)[
'
sample
'
]
logits
=
model
(
noise
,
time_step
)[
"
sample
"
]
assert
torch
.
allclose
(
logits
[
0
,
0
,
0
,
:
30
],
results
[
"_"
.
join
(
"_"
.
join
(
mod
.
modelId
.
split
(
"/"
)).
split
(
"-"
))],
atol
=
1e-3
)
assert
torch
.
allclose
(
logits
[
0
,
0
,
0
,
:
30
],
results
[
"_"
.
join
(
"_"
.
join
(
mod
.
modelId
.
split
(
"/"
)).
split
(
"-"
))],
atol
=
1e-3
)
print
(
f
"
{
mod
.
modelId
}
has passed succesfully!!!"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment