renzhc / diffusers_dcu, commit 89793a97 (unverified)

Authored Aug 25, 2022 by Anton Lozhkov; committed by GitHub on Aug 25, 2022.
Parent: 365f7523

Style the `scripts` directory (#250)

Style scripts

Showing 7 changed files with 452 additions and 315 deletions.
Makefile                                                     +1    -1
scripts/change_naming_configs_and_checkpoints.py             +6    -5
scripts/conversion_ldm_uncond.py                             +5    -5
scripts/convert_ddpm_original_checkpoint_to_diffusers.py     +208  -136
scripts/convert_ldm_original_checkpoint_to_diffusers.py      +123  -97
scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py   +3    -1
scripts/generate_logits.py                                   +106  -70
Makefile
@@ -3,7 +3,7 @@
 # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
 export PYTHONPATH = src

-check_dirs := examples tests src utils
+check_dirs := examples scripts src tests utils

 modified_only_fixup:
 	$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
scripts/change_naming_configs_and_checkpoints.py
@@ -15,12 +15,15 @@
 """ Conversion script for the LDM checkpoints. """

 import argparse
-import os
 import json
+import os

 import torch
-from diffusers import UNet2DModel, UNet2DConditionModel
+from diffusers import UNet2DConditionModel, UNet2DModel
 from transformers.file_utils import has_file

 do_only_config = False
 do_only_weights = True
 do_only_renaming = False
@@ -37,9 +40,7 @@ if __name__ == "__main__":
         help="The config json file corresponding to the architecture.",
     )
-    parser.add_argument(
-        "--dump_path", default=None, type=str, required=True, help="Path to the output model."
-    )
+    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

     args = parser.parse_args()
scripts/conversion_ldm_uncond.py
 import argparse
-import OmegaConf
 import torch
-from diffusers import UNetLDMModel, VQModel, LDMPipeline, DDIMScheduler
+import OmegaConf
+from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel


 def convert_ldm_original(checkpoint_path, config_path, output_path):
     config = OmegaConf.load(config_path)
@@ -16,14 +17,14 @@ def convert_ldm_original(checkpoint_path, config_path, output_path):
     for key in keys:
         if key.startswith(first_stage_key):
             first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key]

     # extract state_dict for UNetLDM
     unet_state_dict = {}
     unet_key = "model.diffusion_model."
     for key in keys:
         if key.startswith(unet_key):
             unet_state_dict[key.replace(unet_key, "")] = state_dict[key]

     vqvae_init_args = config.model.params.first_stage_config.params
     unet_init_args = config.model.params.unet_config.params
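For readers skimming the hunk above, a minimal sketch of the prefix stripping it performs. The value of first_stage_key is defined earlier in the script and is not shown here; "first_stage_model." is assumed, as in typical LDM checkpoints, and the tensor values are placeholders.

# Hypothetical illustration; first_stage_key is an assumption, not taken from this diff.
state_dict = {
    "first_stage_model.encoder.conv_in.weight": "vq-vae tensor",
    "model.diffusion_model.time_embed.0.weight": "unet tensor",
}
first_stage_key = "first_stage_model."
unet_key = "model.diffusion_model."

first_stage_dict = {
    k.replace(first_stage_key, ""): v for k, v in state_dict.items() if k.startswith(first_stage_key)
}
unet_state_dict = {k.replace(unet_key, ""): v for k, v in state_dict.items() if k.startswith(unet_key)}

print(first_stage_dict)  # {'encoder.conv_in.weight': 'vq-vae tensor'}
print(unet_state_dict)   # {'time_embed.0.weight': 'unet tensor'}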
@@ -53,4 +54,3 @@ if __name__ == "__main__":
     args = parser.parse_args()

     convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path)
scripts/convert_ddpm_original_checkpoint_to_diffusers.py
-from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline, VQModel, AutoencoderKL
 import argparse
 import json

 import torch
+
+from diffusers import AutoencoderKL, DDPMPipeline, DDPMScheduler, UNet2DModel, VQModel
 def shave_segments(path, n_shave_prefix_segments=1):
     """
     Removes segments. Positive values shave the first segments, negative shave the last segments.
     """
     if n_shave_prefix_segments >= 0:
-        return '.'.join(path.split('.')[n_shave_prefix_segments:])
+        return ".".join(path.split(".")[n_shave_prefix_segments:])
     else:
-        return '.'.join(path.split('.')[:n_shave_prefix_segments])
+        return ".".join(path.split(".")[:n_shave_prefix_segments])
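As a quick illustration of what shave_segments does, a small self-contained sketch; the example paths are invented and are not taken from any real checkpoint.

def shave_segments(path, n_shave_prefix_segments=1):
    # Same logic as in the script: drop leading (or, for negative n, trailing) dot-separated segments.
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])

print(shave_segments("down.1.block.2.conv1.weight", 2))   # -> "block.2.conv1.weight"
print(shave_segments("down.1.block.2.conv1.weight", -1))  # -> "down.1.block.2.conv1"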
 def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
     mapping = []
     for old_item in old_list:
         new_item = old_item
-        new_item = new_item.replace('block.', 'resnets.')
-        new_item = new_item.replace('conv_shorcut', 'conv1')
-        new_item = new_item.replace('nin_shortcut', 'conv_shortcut')
-        new_item = new_item.replace('temb_proj', 'time_emb_proj')
+        new_item = new_item.replace("block.", "resnets.")
+        new_item = new_item.replace("conv_shorcut", "conv1")
+        new_item = new_item.replace("nin_shortcut", "conv_shortcut")
+        new_item = new_item.replace("temb_proj", "time_emb_proj")

         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

-        mapping.append({'old': old_item, 'new': new_item})
+        mapping.append({"old": old_item, "new": new_item})

     return mapping
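A small, hypothetical example of the mapping this function produces; the key names are illustrative and not taken from a real DDPM checkpoint.

# Using renew_resnet_paths as defined above on two made-up DDPM-style keys.
paths = renew_resnet_paths(["block.0.temb_proj.weight", "block.1.nin_shortcut.bias"])
# -> [{'old': 'block.0.temb_proj.weight', 'new': 'resnets.0.time_emb_proj.weight'},
#     {'old': 'block.1.nin_shortcut.bias', 'new': 'resnets.1.conv_shortcut.bias'}]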
@@ -37,21 +39,23 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0, in_mid=False):
         # In `model.mid`, the layer is called `attn`.
         if not in_mid:
-            new_item = new_item.replace('attn', 'attentions')
-        new_item = new_item.replace('.k.', '.key.')
-        new_item = new_item.replace('.v.', '.value.')
-        new_item = new_item.replace('.q.', '.query.')
+            new_item = new_item.replace("attn", "attentions")
+        new_item = new_item.replace(".k.", ".key.")
+        new_item = new_item.replace(".v.", ".value.")
+        new_item = new_item.replace(".q.", ".query.")

-        new_item = new_item.replace('proj_out', 'proj_attn')
-        new_item = new_item.replace('norm', 'group_norm')
+        new_item = new_item.replace("proj_out", "proj_attn")
+        new_item = new_item.replace("norm", "group_norm")

         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

-        mapping.append({'old': old_item, 'new': new_item})
+        mapping.append({"old": old_item, "new": new_item})

     return mapping
-def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
+def assign_to_checkpoint(
+    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
+):
     assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

     if attention_paths_to_split is not None:
@@ -69,27 +73,27 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
             old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
             query, key, value = old_tensor.split(channels // num_heads, dim=1)

-            checkpoint[path_map['query']] = query.reshape(target_shape).squeeze()
-            checkpoint[path_map['key']] = key.reshape(target_shape).squeeze()
-            checkpoint[path_map['value']] = value.reshape(target_shape).squeeze()
+            checkpoint[path_map["query"]] = query.reshape(target_shape).squeeze()
+            checkpoint[path_map["key"]] = key.reshape(target_shape).squeeze()
+            checkpoint[path_map["value"]] = value.reshape(target_shape).squeeze()

     for path in paths:
-        new_path = path['new']
+        new_path = path["new"]

         if attention_paths_to_split is not None and new_path in attention_paths_to_split:
             continue

-        new_path = new_path.replace('down.', 'down_blocks.')
-        new_path = new_path.replace('up.', 'up_blocks.')
+        new_path = new_path.replace("down.", "down_blocks.")
+        new_path = new_path.replace("up.", "up_blocks.")

         if additional_replacements is not None:
             for replacement in additional_replacements:
-                new_path = new_path.replace(replacement['old'], replacement['new'])
+                new_path = new_path.replace(replacement["old"], replacement["new"])

-        if 'attentions' in new_path:
-            checkpoint[new_path] = old_checkpoint[path['old']].squeeze()
+        if "attentions" in new_path:
+            checkpoint[new_path] = old_checkpoint[path["old"]].squeeze()
         else:
-            checkpoint[new_path] = old_checkpoint[path['old']]
+            checkpoint[new_path] = old_checkpoint[path["old"]]


 def convert_ddpm_checkpoint(checkpoint, config):
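The qkv-splitting branch above is easier to follow with concrete shapes. A minimal sketch with made-up dimensions follows; num_heads, channels and target_shape are whatever the caller derives from the config and are assumptions here, not values from the diff.

import torch

# Hypothetical sizes: 1 head, 4 channels, spatial size 1x1, fused qkv weight of shape (3*channels, channels, 1, 1).
num_heads, channels = 1, 4
target_shape = (channels, channels)  # assumption for this sketch
old_tensor = torch.randn(3 * channels, channels, 1, 1)

# Same two steps as in assign_to_checkpoint: reshape so the fused axis can be split per head,
# then cut it into equal query/key/value chunks.
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)

print(query.reshape(target_shape).squeeze().shape)  # torch.Size([4, 4]): one standalone projection weight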
@@ -98,49 +102,63 @@ def convert_ddpm_checkpoint(checkpoint, config):
     """
     new_checkpoint = {}

-    new_checkpoint['time_embedding.linear_1.weight'] = checkpoint['temb.dense.0.weight']
-    new_checkpoint['time_embedding.linear_1.bias'] = checkpoint['temb.dense.0.bias']
-    new_checkpoint['time_embedding.linear_2.weight'] = checkpoint['temb.dense.1.weight']
-    new_checkpoint['time_embedding.linear_2.bias'] = checkpoint['temb.dense.1.bias']
+    new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["temb.dense.0.weight"]
+    new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["temb.dense.0.bias"]
+    new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["temb.dense.1.weight"]
+    new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["temb.dense.1.bias"]

-    new_checkpoint['conv_norm_out.weight'] = checkpoint['norm_out.weight']
-    new_checkpoint['conv_norm_out.bias'] = checkpoint['norm_out.bias']
+    new_checkpoint["conv_norm_out.weight"] = checkpoint["norm_out.weight"]
+    new_checkpoint["conv_norm_out.bias"] = checkpoint["norm_out.bias"]

-    new_checkpoint['conv_in.weight'] = checkpoint['conv_in.weight']
-    new_checkpoint['conv_in.bias'] = checkpoint['conv_in.bias']
-    new_checkpoint['conv_out.weight'] = checkpoint['conv_out.weight']
-    new_checkpoint['conv_out.bias'] = checkpoint['conv_out.bias']
+    new_checkpoint["conv_in.weight"] = checkpoint["conv_in.weight"]
+    new_checkpoint["conv_in.bias"] = checkpoint["conv_in.bias"]
+    new_checkpoint["conv_out.weight"] = checkpoint["conv_out.weight"]
+    new_checkpoint["conv_out.bias"] = checkpoint["conv_out.bias"]

-    num_down_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
-    down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}
+    num_down_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "down" in layer})
+    down_blocks = {
+        layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+    }

-    num_up_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
-    up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
+    num_up_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "up" in layer})
+    up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}

     for i in range(num_down_blocks):
-        block_id = (i - 1) // (config['layers_per_block'] + 1)
+        block_id = (i - 1) // (config["layers_per_block"] + 1)

-        if any('downsample' in layer for layer in down_blocks[i]):
-            new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.op.weight']
-            new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'down.{i}.downsample.op.bias']
+        if any("downsample" in layer for layer in down_blocks[i]):
+            new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[f"down.{i}.downsample.op.weight"]
+            new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[f"down.{i}.downsample.op.bias"]
             # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
             # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']

-        if any('block' in layer for layer in down_blocks[i]):
-            num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'block' in layer})
-            blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+        if any("block" in layer for layer in down_blocks[i]):
+            num_blocks = len(
+                {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "block" in layer}
+            )
+            blocks = {
+                layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
+                for layer_id in range(num_blocks)
+            }

             if num_blocks > 0:
-                for j in range(config['layers_per_block']):
+                for j in range(config["layers_per_block"]):
                     paths = renew_resnet_paths(blocks[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint)

-        if any('attn' in layer for layer in down_blocks[i]):
-            num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'attn' in layer})
-            attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+        if any("attn" in layer for layer in down_blocks[i]):
+            num_attn = len(
+                {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "attn" in layer}
+            )
+            attns = {
+                layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
+                for layer_id in range(num_blocks)
+            }

             if num_attn > 0:
-                for j in range(config['layers_per_block']):
+                for j in range(config["layers_per_block"]):
                     paths = renew_attention_paths(attns[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
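To make the key-grouping comprehensions above concrete, a small sketch with a toy checkpoint dict; the key names are invented for illustration only.

# Toy stand-in for the original DDPM state dict; only the key names matter here.
checkpoint = {
    "down.0.block.0.conv1.weight": 0,
    "down.0.block.1.conv1.weight": 0,
    "down.1.downsample.op.weight": 0,
}

num_down_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "down" in layer})
down_blocks = {
    layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}

print(num_down_blocks)  # 2  (prefixes "down.0" and "down.1")
print(down_blocks[1])   # ['down.1.downsample.op.weight']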
@@ -150,48 +168,67 @@ def convert_ddpm_checkpoint(checkpoint, config):
     # Mid new 2
     paths = renew_resnet_paths(mid_block_1_layers)
-    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'}])
+    assign_to_checkpoint(
+        paths,
+        new_checkpoint,
+        checkpoint,
+        additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
+    )

     paths = renew_resnet_paths(mid_block_2_layers)
-    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_2', 'new': 'resnets.1'}])
+    assign_to_checkpoint(
+        paths,
+        new_checkpoint,
+        checkpoint,
+        additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
+    )

     paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
-    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}])
+    assign_to_checkpoint(
+        paths,
+        new_checkpoint,
+        checkpoint,
+        additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
+    )

     for i in range(num_up_blocks):
         block_id = num_up_blocks - 1 - i

-        if any('upsample' in layer for layer in up_blocks[i]):
-            new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'up.{i}.upsample.conv.weight']
-            new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'up.{i}.upsample.conv.bias']
+        if any("upsample" in layer for layer in up_blocks[i]):
+            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[f"up.{i}.upsample.conv.weight"]
+            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[f"up.{i}.upsample.conv.bias"]

     ... (the rest of the hunk repeats the pattern shown for the down blocks: the 'block' and 'attn'
     groupings for each up block, the loops over config["layers_per_block"] + 1, and the
     replace_indices dicts mapping f"up_blocks.{i}" to f"up_blocks.{block_id}" all move from single to
     double quotes, with the long set/dict comprehensions re-wrapped across several lines; the hunk ends
     with the mid_new_2 -> mid_block key rename, unchanged except for quote style, and `return new_checkpoint`)
@@ -201,50 +238,66 @@ def convert_vq_autoenc_checkpoint(checkpoint, config):
     """
     new_checkpoint = {}

-    new_checkpoint['encoder.conv_norm_out.weight'] = checkpoint['encoder.norm_out.weight']
-    new_checkpoint['encoder.conv_norm_out.bias'] = checkpoint['encoder.norm_out.bias']
+    new_checkpoint["encoder.conv_norm_out.weight"] = checkpoint["encoder.norm_out.weight"]
+    new_checkpoint["encoder.conv_norm_out.bias"] = checkpoint["encoder.norm_out.bias"]

     ... (the remaining lines of the hunk follow the same pattern as convert_ddpm_checkpoint above: the
     encoder.conv_in/conv_out, decoder.conv_norm_out, decoder.conv_in/conv_out copies, the
     num_down_blocks/num_up_blocks groupings (built here from the first three key segments), the encoder
     down-block downsample/resnet/attention handling and the per-block loops over
     config["layers_per_block"] all switch from single to double quotes, with the long comprehensions
     re-wrapped by the formatter)
@@ -254,48 +307,69 @@ def convert_vq_autoenc_checkpoint(checkpoint, config):
     # Mid new 2
     paths = renew_resnet_paths(mid_block_1_layers)
-    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'}])
+    assign_to_checkpoint(
+        paths,
+        new_checkpoint,
+        checkpoint,
+        additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
+    )

     ... (the block_2 -> resnets.1 and attn_1 -> attentions.0 calls are re-wrapped the same way; the
     decoder up-block loop with its upsample/block/attention groupings and replace_indices dicts gets
     the same quote conversion and re-wrapping as in convert_ddpm_checkpoint; the hunk ends with the
     mid_new_2 -> mid_block rename, the unchanged quant_conv.weight / quant_conv.bias copies and the
     `if "quantize.embedding.weight" in checkpoint:` check)
@@ -321,9 +395,7 @@ if __name__ == "__main__":
         help="The config json file corresponding to the architecture.",
     )
-    parser.add_argument(
-        "--dump_path", default=None, type=str, required=True, help="Path to the output model."
-    )
+    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

     args = parser.parse_args()

     checkpoint = torch.load(args.checkpoint_path)
scripts/convert_ldm_original_checkpoint_to_diffusers.py
@@ -16,8 +16,10 @@
 import argparse
 import json
+
 import torch
-from diffusers import VQModel, DDPMScheduler, UNet2DModel, LDMPipeline
+
+from diffusers import DDPMScheduler, LDMPipeline, UNet2DModel, VQModel


 def shave_segments(path, n_shave_prefix_segments=1):
@@ -25,9 +27,9 @@ def shave_segments(path, n_shave_prefix_segments=1):
     Removes segments. Positive values shave the first segments, negative shave the last segments.
     """
     if n_shave_prefix_segments >= 0:
-        return '.'.join(path.split('.')[n_shave_prefix_segments:])
+        return ".".join(path.split(".")[n_shave_prefix_segments:])
     else:
-        return '.'.join(path.split('.')[:n_shave_prefix_segments])
+        return ".".join(path.split(".")[:n_shave_prefix_segments])


 def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
@@ -36,18 +38,18 @@ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
     """
     mapping = []
     for old_item in old_list:
-        new_item = old_item.replace('in_layers.0', 'norm1')
-        new_item = new_item.replace('in_layers.2', 'conv1')
+        new_item = old_item.replace("in_layers.0", "norm1")
+        new_item = new_item.replace("in_layers.2", "conv1")

-        new_item = new_item.replace('out_layers.0', 'norm2')
-        new_item = new_item.replace('out_layers.3', 'conv2')
+        new_item = new_item.replace("out_layers.0", "norm2")
+        new_item = new_item.replace("out_layers.3", "conv2")

-        new_item = new_item.replace('emb_layers.1', 'time_emb_proj')
-        new_item = new_item.replace('skip_connection', 'conv_shortcut')
+        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
+        new_item = new_item.replace("skip_connection", "conv_shortcut")

         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

-        mapping.append({'old': old_item, 'new': new_item})
+        mapping.append({"old": old_item, "new": new_item})

     return mapping
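As with the DDPM helper earlier, a small hypothetical example of the LDM-style renaming this version performs; the key name and the n_shave_prefix_segments value are invented for illustration.

# Using renew_resnet_paths as defined above on a made-up LDM key.
paths = renew_resnet_paths(["input_blocks.1.0.in_layers.0.weight"], n_shave_prefix_segments=2)
# replace: in_layers.0 -> norm1, then shave the first two segments ("input_blocks" and "1"):
# -> [{'old': 'input_blocks.1.0.in_layers.0.weight', 'new': '0.norm1.weight'}]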
@@ -60,20 +62,22 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
     for old_item in old_list:
         new_item = old_item

-        new_item = new_item.replace('norm.weight', 'group_norm.weight')
-        new_item = new_item.replace('norm.bias', 'group_norm.bias')
+        new_item = new_item.replace("norm.weight", "group_norm.weight")
+        new_item = new_item.replace("norm.bias", "group_norm.bias")

-        new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
-        new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
+        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
+        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")

         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

-        mapping.append({'old': old_item, 'new': new_item})
+        mapping.append({"old": old_item, "new": new_item})

     return mapping


-def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
+def assign_to_checkpoint(
+    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
+):
     """
     This does the final conversion step: take locally converted weights and apply a global renaming
     to them. It splits attention layers, and takes into account additional replacements
@@ -96,31 +100,31 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
         old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
         query, key, value = old_tensor.split(channels // num_heads, dim=1)

-        checkpoint[path_map['query']] = query.reshape(target_shape)
-        checkpoint[path_map['key']] = key.reshape(target_shape)
-        checkpoint[path_map['value']] = value.reshape(target_shape)
+        checkpoint[path_map["query"]] = query.reshape(target_shape)
+        checkpoint[path_map["key"]] = key.reshape(target_shape)
+        checkpoint[path_map["value"]] = value.reshape(target_shape)

     for path in paths:
-        new_path = path['new']
+        new_path = path["new"]

         # These have already been assigned
         if attention_paths_to_split is not None and new_path in attention_paths_to_split:
             continue

         # Global renaming happens here
-        new_path = new_path.replace('middle_block.0', 'mid.resnets.0')
-        new_path = new_path.replace('middle_block.1', 'mid.attentions.0')
-        new_path = new_path.replace('middle_block.2', 'mid.resnets.1')
+        new_path = new_path.replace("middle_block.0", "mid.resnets.0")
+        new_path = new_path.replace("middle_block.1", "mid.attentions.0")
+        new_path = new_path.replace("middle_block.2", "mid.resnets.1")

         if additional_replacements is not None:
             for replacement in additional_replacements:
-                new_path = new_path.replace(replacement['old'], replacement['new'])
+                new_path = new_path.replace(replacement["old"], replacement["new"])

         # proj_attn.weight has to be converted from conv 1D to linear
         if "proj_attn.weight" in new_path:
-            checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0]
+            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
         else:
-            checkpoint[new_path] = old_checkpoint[path['old']]
+            checkpoint[new_path] = old_checkpoint[path["old"]]


 def convert_ldm_checkpoint(checkpoint, config):
@@ -129,60 +133,78 @@ def convert_ldm_checkpoint(checkpoint, config):
     """
     new_checkpoint = {}

-    new_checkpoint['time_embedding.linear_1.weight'] = checkpoint['time_embed.0.weight']
-    new_checkpoint['time_embedding.linear_1.bias'] = checkpoint['time_embed.0.bias']
-    new_checkpoint['time_embedding.linear_2.weight'] = checkpoint['time_embed.2.weight']
-    new_checkpoint['time_embedding.linear_2.bias'] = checkpoint['time_embed.2.bias']
+    new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"]
+    new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"]
+    new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"]
+    new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"]

-    new_checkpoint['conv_in.weight'] = checkpoint['input_blocks.0.0.weight']
-    new_checkpoint['conv_in.bias'] = checkpoint['input_blocks.0.0.bias']
+    new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"]
+    new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"]

-    new_checkpoint['conv_norm_out.weight'] = checkpoint['out.0.weight']
-    new_checkpoint['conv_norm_out.bias'] = checkpoint['out.0.bias']
-    new_checkpoint['conv_out.weight'] = checkpoint['out.2.weight']
-    new_checkpoint['conv_out.bias'] = checkpoint['out.2.bias']
+    new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
+    new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"]
+    new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"]
+    new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"]

     # Retrieves the keys for the input blocks only
-    num_input_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'input_blocks' in layer})
-    input_blocks = {layer_id: [key for key in checkpoint if f'input_blocks.{layer_id}' in key] for layer_id in range(num_input_blocks)}
+    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "input_blocks" in layer})
+    input_blocks = {
+        layer_id: [key for key in checkpoint if f"input_blocks.{layer_id}" in key]
+        for layer_id in range(num_input_blocks)
+    }

     ... (the middle_blocks and output_blocks groupings are rewritten the same way; inside the
     `for i in range(1, num_input_blocks)` loop the config['num_res_blocks'] lookups, the resnets /
     attentions key filters, the f'input_blocks.{i}.0.op.*' downsampler copies, the meta_path and
     resnet_op dicts, and the large `to_split` mapping of the fused qkv weights and biases onto
     downsample_blocks.{block_id}.attentions.{layer_in_block_id}.{query,key,value}.{weight,bias} all
     switch to double quotes, with the multi-line dict literals re-wrapped and given trailing commas;
     the hunk ends at the first `assign_to_checkpoint(` call of the loop)
@@ -191,7 +213,7 @@ def convert_ldm_checkpoint(checkpoint, config):
                 checkpoint,
                 additional_replacements=[meta_path],
                 attention_paths_to_split=to_split,
-                config=config
+                config=config,
             )

     resnet_0 = middle_blocks[0]
@@ -206,46 +228,52 @@ def convert_ldm_checkpoint(checkpoint, config):
     attentions_paths = renew_attention_paths(attentions)
     to_split = {
-        'middle_block.1.qkv.bias': {
-            'key': 'mid_block.attentions.0.key.bias',
-            'query': 'mid_block.attentions.0.query.bias',
-            'value': 'mid_block.attentions.0.value.bias',
+        "middle_block.1.qkv.bias": {
+            "key": "mid_block.attentions.0.key.bias",
+            "query": "mid_block.attentions.0.query.bias",
+            "value": "mid_block.attentions.0.value.bias",
         },
-        'middle_block.1.qkv.weight': {
-            'key': 'mid_block.attentions.0.key.weight',
-            'query': 'mid_block.attentions.0.query.weight',
-            'value': 'mid_block.attentions.0.value.weight',
+        "middle_block.1.qkv.weight": {
+            "key": "mid_block.attentions.0.key.weight",
+            "query": "mid_block.attentions.0.query.weight",
+            "value": "mid_block.attentions.0.value.weight",
         },
     }
     assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config)

     for i in range(num_output_blocks):
-        block_id = i // (config['num_res_blocks'] + 1)
-        layer_in_block_id = i % (config['num_res_blocks'] + 1)
+        block_id = i // (config["num_res_blocks"] + 1)
+        layer_in_block_id = i % (config["num_res_blocks"] + 1)
         output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
         output_block_list = {}

         for layer in output_block_layers:
-            layer_id, layer_name = layer.split('.')[0], shave_segments(layer, 1)
+            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
             if layer_id in output_block_list:
                 output_block_list[layer_id].append(layer_name)
             else:
                 output_block_list[layer_id] = [layer_name]

     ... (in the `if len(output_block_list) > 1:` branch the resnets/attentions filters on
     f"output_blocks.{i}.0" and f"output_blocks.{i}.1", the meta_path dict, and the
     ['conv.weight', 'conv.bias'] upsampler lookup with its two long f-string checkpoint assignments all
     move to double quotes, with the long assignments re-wrapped by the formatter; the hunk ends at the
     "# Clear attentions as they have been attributed above." comment and the `if len(attentions) == 2:` check)
@@ -254,19 +282,19 @@ def convert_ldm_checkpoint(checkpoint, config):
         if len(attentions):
             paths = renew_attention_paths(attentions)
             meta_path = {
-                'old': f'output_blocks.{i}.1',
-                'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}'
+                "old": f"output_blocks.{i}.1",
+                "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
             }
             to_split = {
-                f'output_blocks.{i}.1.qkv.bias': {
-                    'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
-                    'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
-                    'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
+                f"output_blocks.{i}.1.qkv.bias": {
+                    "key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
+                    "query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
+                    "value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
                 },
-                f'output_blocks.{i}.1.qkv.weight': {
-                    'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
-                    'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
-                    'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
+                f"output_blocks.{i}.1.qkv.weight": {
+                    "key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
+                    "query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
+                    "value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
                 },
             }
             assign_to_checkpoint(
@@ -274,14 +302,14 @@ def convert_ldm_checkpoint(checkpoint, config):
                 new_checkpoint,
                 checkpoint,
                 additional_replacements=[meta_path],
-                attention_paths_to_split=to_split if any('qkv' in key for key in attentions) else None,
+                attention_paths_to_split=to_split if any("qkv" in key for key in attentions) else None,
                 config=config,
             )
         else:
             resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
             for path in resnet_0_paths:
-                old_path = '.'.join(['output_blocks', str(i), path['old']])
-                new_path = '.'.join(['up_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']])
+                old_path = ".".join(["output_blocks", str(i), path["old"]])
+                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                 new_checkpoint[new_path] = checkpoint[old_path]
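A small, hypothetical illustration of the path join in the else-branch above; the i, block_id and layer_in_block_id values and the path entry are invented for the example.

i, block_id, layer_in_block_id = 3, 1, 0
path = {"old": "conv1.weight", "new": "conv1.weight"}  # made-up entry of the shape renew_resnet_paths returns

old_path = ".".join(["output_blocks", str(i), path["old"]])
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

print(old_path)  # output_blocks.3.conv1.weight
print(new_path)  # up_blocks.1.resnets.0.conv1.weight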
@@ -303,9 +331,7 @@ if __name__ == "__main__":
         help="The config json file corresponding to the architecture.",
     )
-    parser.add_argument(
-        "--dump_path", default=None, type=str, required=True, help="Path to the output model."
-    )
+    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

     args = parser.parse_args()
scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py
@@ -16,8 +16,10 @@
 import argparse
 import json
+
 import torch
-from diffusers import UNet2DModel
+
+from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel


 def convert_ncsnpp_checkpoint(checkpoint, config):
scripts/generate_logits.py
-from huggingface_hub import HfApi
-from transformers.file_utils import has_file
-from diffusers import UNet2DModel
 import random
+
 import torch
+
+from diffusers import UNet2DModel
+from huggingface_hub import HfApi

 api = HfApi()
 results = {}
-results["google_ddpm_cifar10_32"] = torch.tensor([-0.7515, -1.6883, 0.2420, 0.0300, 0.6347, 1.3433, -1.1743, -3.7467, 1.2342, -2.2485, 0.4636, 0.8076, -0.7991, 0.3969, 0.8498, 0.9189, -1.8887, -3.3522, 0.7639, 0.2040, 0.6271, -2.7148, -1.6316, 3.0839, 0.3186, 0.2721, -0.9759, -1.2461, 2.6257, 1.3557])
 ... (followed by the reference logit tensors for google_ddpm_ema_bedroom_256, CompVis_ldm_celebahq_256,
 google_ncsnpp_ffhq_1024, google_ncsnpp_bedroom_256, google_ncsnpp_celebahq_256, google_ncsnpp_church_256,
 google_ncsnpp_ffhq_256, google_ddpm_cat_256, google_ddpm_celebahq_256, google_ddpm_ema_celebahq_256,
 google_ddpm_church_256, google_ddpm_bedroom_256, google_ddpm_ema_church_256 and google_ddpm_ema_cat_256,
 each a 30-element torch.tensor literal)
+# fmt: off
+results["google_ddpm_cifar10_32"] = torch.tensor([-0.7515, -1.6883, 0.2420, 0.0300, 0.6347, 1.3433, -1.1743, -3.7467, 1.2342, -2.2485, 0.4636, 0.8076, -0.7991, 0.3969, 0.8498, 0.9189, -1.8887, -3.3522, 0.7639, 0.2040, 0.6271, -2.7148, -1.6316, 3.0839, 0.3186, 0.2721, -0.9759, -1.2461, 2.6257, 1.3557])
 ... (the same fifteen tensor assignments, with identical values, reproduced inside the new
 `# fmt: off` / `# fmt: on` guards so that each tensor stays on a single line)
+# fmt: on
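The guards exist purely for the formatter. A tiny hypothetical illustration of their effect follows; black as the formatter is an assumption here (the commit itself only shows the guards being added), and the short tensor is invented.

import torch

# Without the guards a formatter such as black would split a long literal across many lines;
# with them the line is left exactly as written.
# fmt: off
reference = torch.tensor([-0.7515, -1.6883, 0.2420, 0.0300, 0.6347, 1.3433, -1.1743, -3.7467])
# fmt: on
print(reference.shape)  # torch.Size([8])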
 models = api.list_models(filter="diffusers")
 for mod in models:
     if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256":
         local_checkpoint = "/home/patrick/google_checkpoints/" + mod.modelId.split("/")[-1]

         print(f"Started running {mod.modelId}!!!")

         if mod.modelId.startswith("CompVis"):
             model = UNet2DModel.from_pretrained(local_checkpoint, subfolder="unet")
         else:
             model = UNet2DModel.from_pretrained(local_checkpoint)

         torch.manual_seed(0)
         random.seed(0)

         noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
         time_step = torch.tensor([10] * noise.shape[0])
         with torch.no_grad():
-            logits = model(noise, time_step)['sample']
+            logits = model(noise, time_step)["sample"]

-        assert torch.allclose(logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3)
+        assert torch.allclose(
+            logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3
+        )

         print(f"{mod.modelId} has passed succesfully!!!")