renzhc / diffusers_dcu · Commits · 89793a97

Unverified commit 89793a97, authored Aug 25, 2022 by Anton Lozhkov, committed by GitHub on Aug 25, 2022

Style the `scripts` directory (#250)

Style scripts

parent 365f7523
Showing 7 changed files with 452 additions and 315 deletions (+452, -315):

  Makefile                                                      +1    -1
  scripts/change_naming_configs_and_checkpoints.py              +6    -5
  scripts/conversion_ldm_uncond.py                              +5    -5
  scripts/convert_ddpm_original_checkpoint_to_diffusers.py    +208  -136
  scripts/convert_ldm_original_checkpoint_to_diffusers.py     +123   -97
  scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py    +3    -1
  scripts/generate_logits.py                                  +106   -70
Makefile

@@ -3,7 +3,7 @@
 # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
 export PYTHONPATH = src

-check_dirs := examples tests src utils
+check_dirs := examples scripts src tests utils

 modified_only_fixup:
	$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
scripts/change_naming_configs_and_checkpoints.py

@@ -15,12 +15,15 @@
 """ Conversion script for the LDM checkpoints. """

 import argparse
-import os
 import json
+import os

 import torch

-from diffusers import UNet2DModel, UNet2DConditionModel
+from diffusers import UNet2DConditionModel, UNet2DModel
 from transformers.file_utils import has_file

 do_only_config = False
 do_only_weights = True
 do_only_renaming = False

@@ -37,9 +40,7 @@ if __name__ == "__main__":
         help="The config json file corresponding to the architecture.",
     )
-    parser.add_argument(
-        "--dump_path", default=None, type=str, required=True, help="Path to the output model."
-    )
+    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

     args = parser.parse_args()
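The `has_file` import kept by this hunk checks whether a file exists in a remote Hub repo without downloading it. A hedged sketch of how it is typically called (assuming the helper keeps the (path_or_repo, filename) signature it had in transformers.file_utils at the time; the repo id and filename here are illustrative, not taken from this diff):

from transformers.file_utils import has_file

# Probe for a weights file on the Hub before attempting a conversion.
if has_file("google/ddpm-cifar10-32", "diffusion_pytorch_model.bin"):
    print("weights file present")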
scripts/conversion_ldm_uncond.py

 import argparse
-import OmegaConf

 import torch

-from diffusers import UNetLDMModel, VQModel, LDMPipeline, DDIMScheduler
+import OmegaConf
+from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel


 def convert_ldm_original(checkpoint_path, config_path, output_path):
     config = OmegaConf.load(config_path)

@@ -16,14 +17,14 @@ def convert_ldm_original(checkpoint_path, config_path, output_path):
     for key in keys:
         if key.startswith(first_stage_key):
             first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key]

     # extract state_dict for UNetLDM
     unet_state_dict = {}
     unet_key = "model.diffusion_model."
     for key in keys:
         if key.startswith(unet_key):
             unet_state_dict[key.replace(unet_key, "")] = state_dict[key]

     vqvae_init_args = config.model.params.first_stage_config.params
     unet_init_args = config.model.params.unet_config.params

@@ -53,4 +54,3 @@ if __name__ == "__main__":
     args = parser.parse_args()
     convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path)
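The two extraction loops above implement a single pattern: collect every key under a prefix and strip that prefix. A self-contained sketch of the same pattern, with toy string values standing in for real tensors:

state_dict = {
    "first_stage_model.encoder.conv_in.weight": "vq tensor",
    "model.diffusion_model.input_blocks.0.0.weight": "unet tensor",
}
unet_key = "model.diffusion_model."
unet_state_dict = {}
for key in state_dict:
    if key.startswith(unet_key):
        # drop the prefix so the sub-model can load the dict directly
        unet_state_dict[key.replace(unet_key, "")] = state_dict[key]
assert unet_state_dict == {"input_blocks.0.0.weight": "unet tensor"}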
scripts/convert_ddpm_original_checkpoint_to_diffusers.py

-from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline, VQModel, AutoencoderKL
 import argparse
 import json

 import torch

+from diffusers import AutoencoderKL, DDPMPipeline, DDPMScheduler, UNet2DModel, VQModel


 def shave_segments(path, n_shave_prefix_segments=1):
     """
     Removes segments. Positive values shave the first segments, negative shave the last segments.
     """
     if n_shave_prefix_segments >= 0:
-        return '.'.join(path.split('.')[n_shave_prefix_segments:])
+        return ".".join(path.split(".")[n_shave_prefix_segments:])
     else:
-        return '.'.join(path.split('.')[:n_shave_prefix_segments])
+        return ".".join(path.split(".")[:n_shave_prefix_segments])


 def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
     mapping = []
     for old_item in old_list:
         new_item = old_item
-        new_item = new_item.replace('block.', 'resnets.')
-        new_item = new_item.replace('conv_shorcut', 'conv1')
-        new_item = new_item.replace('nin_shortcut', 'conv_shortcut')
-        new_item = new_item.replace('temb_proj', 'time_emb_proj')
+        new_item = new_item.replace("block.", "resnets.")
+        new_item = new_item.replace("conv_shorcut", "conv1")
+        new_item = new_item.replace("nin_shortcut", "conv_shortcut")
+        new_item = new_item.replace("temb_proj", "time_emb_proj")

         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

-        mapping.append({'old': old_item, 'new': new_item})
+        mapping.append({"old": old_item, "new": new_item})

     return mapping
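The two helpers above are pure string manipulation, so their behavior is easy to pin down. A quick check, assuming both functions are in scope:

assert shave_segments("down.1.block.0.conv1.weight", 2) == "block.0.conv1.weight"
assert shave_segments("down.1.block.0.conv1.weight", -1) == "down.1.block.0.conv1"

mapping = renew_resnet_paths(["block.0.temb_proj.weight"])
assert mapping == [{"old": "block.0.temb_proj.weight", "new": "resnets.0.time_emb_proj.weight"}]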
@@ -37,21 +39,23 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0, in_mid=False):
         # In `model.mid`, the layer is called `attn`.
         if not in_mid:
-            new_item = new_item.replace('attn', 'attentions')
+            new_item = new_item.replace("attn", "attentions")
-        new_item = new_item.replace('.k.', '.key.')
-        new_item = new_item.replace('.v.', '.value.')
-        new_item = new_item.replace('.q.', '.query.')
+        new_item = new_item.replace(".k.", ".key.")
+        new_item = new_item.replace(".v.", ".value.")
+        new_item = new_item.replace(".q.", ".query.")

-        new_item = new_item.replace('proj_out', 'proj_attn')
-        new_item = new_item.replace('norm', 'group_norm')
+        new_item = new_item.replace("proj_out", "proj_attn")
+        new_item = new_item.replace("norm", "group_norm")

         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

-        mapping.append({'old': old_item, 'new': new_item})
+        mapping.append({"old": old_item, "new": new_item})

     return mapping
 def assign_to_checkpoint(
     paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
 ):
     assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

     if attention_paths_to_split is not None:

@@ -69,27 +73,27 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
             old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
             query, key, value = old_tensor.split(channels // num_heads, dim=1)

-            checkpoint[path_map['query']] = query.reshape(target_shape).squeeze()
-            checkpoint[path_map['key']] = key.reshape(target_shape).squeeze()
-            checkpoint[path_map['value']] = value.reshape(target_shape).squeeze()
+            checkpoint[path_map["query"]] = query.reshape(target_shape).squeeze()
+            checkpoint[path_map["key"]] = key.reshape(target_shape).squeeze()
+            checkpoint[path_map["value"]] = value.reshape(target_shape).squeeze()

     for path in paths:
-        new_path = path['new']
+        new_path = path["new"]

         if attention_paths_to_split is not None and new_path in attention_paths_to_split:
             continue

-        new_path = new_path.replace('down.', 'down_blocks.')
-        new_path = new_path.replace('up.', 'up_blocks.')
+        new_path = new_path.replace("down.", "down_blocks.")
+        new_path = new_path.replace("up.", "up_blocks.")

         if additional_replacements is not None:
             for replacement in additional_replacements:
-                new_path = new_path.replace(replacement['old'], replacement['new'])
+                new_path = new_path.replace(replacement["old"], replacement["new"])

-        if 'attentions' in new_path:
-            checkpoint[new_path] = old_checkpoint[path['old']].squeeze()
+        if "attentions" in new_path:
+            checkpoint[new_path] = old_checkpoint[path["old"]].squeeze()
         else:
-            checkpoint[new_path] = old_checkpoint[path['old']]
+            checkpoint[new_path] = old_checkpoint[path["old"]]
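The head-wise qkv split at the top of assign_to_checkpoint is the only numerically subtle step here. A toy reproduction of it (shapes are illustrative, not taken from a real checkpoint):

import torch

num_heads, channels = 2, 4
old_tensor = torch.randn(3 * channels, channels, 1, 1)  # fused qkv conv weight
target_shape = (channels, channels, 1, 1)

# group the fused projection per head, then cut it into three equal chunks
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)

# each chunk reshapes back into a standalone per-projection weight
assert query.reshape(target_shape).shape == (4, 4, 1, 1)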
 def convert_ddpm_checkpoint(checkpoint, config):

@@ -98,49 +102,63 @@ def convert_ddpm_checkpoint(checkpoint, config):
     """
     new_checkpoint = {}

-    new_checkpoint['time_embedding.linear_1.weight'] = checkpoint['temb.dense.0.weight']
-    new_checkpoint['time_embedding.linear_1.bias'] = checkpoint['temb.dense.0.bias']
-    new_checkpoint['time_embedding.linear_2.weight'] = checkpoint['temb.dense.1.weight']
-    new_checkpoint['time_embedding.linear_2.bias'] = checkpoint['temb.dense.1.bias']
+    new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["temb.dense.0.weight"]
+    new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["temb.dense.0.bias"]
+    new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["temb.dense.1.weight"]
+    new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["temb.dense.1.bias"]

-    new_checkpoint['conv_norm_out.weight'] = checkpoint['norm_out.weight']
-    new_checkpoint['conv_norm_out.bias'] = checkpoint['norm_out.bias']
+    new_checkpoint["conv_norm_out.weight"] = checkpoint["norm_out.weight"]
+    new_checkpoint["conv_norm_out.bias"] = checkpoint["norm_out.bias"]

-    new_checkpoint['conv_in.weight'] = checkpoint['conv_in.weight']
-    new_checkpoint['conv_in.bias'] = checkpoint['conv_in.bias']
-    new_checkpoint['conv_out.weight'] = checkpoint['conv_out.weight']
-    new_checkpoint['conv_out.bias'] = checkpoint['conv_out.bias']
+    new_checkpoint["conv_in.weight"] = checkpoint["conv_in.weight"]
+    new_checkpoint["conv_in.bias"] = checkpoint["conv_in.bias"]
+    new_checkpoint["conv_out.weight"] = checkpoint["conv_out.weight"]
+    new_checkpoint["conv_out.bias"] = checkpoint["conv_out.bias"]

-    num_down_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
-    down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}
+    num_down_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "down" in layer})
+    down_blocks = {
+        layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+    }

-    num_up_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
-    up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
+    num_up_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "up" in layer})
+    up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}

     for i in range(num_down_blocks):
-        block_id = (i - 1) // (config['layers_per_block'] + 1)
+        block_id = (i - 1) // (config["layers_per_block"] + 1)

-        if any('downsample' in layer for layer in down_blocks[i]):
-            new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.op.weight']
-            new_checkpoint[f'down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'down.{i}.downsample.op.bias']
+        if any("downsample" in layer for layer in down_blocks[i]):
+            new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
+                f"down.{i}.downsample.op.weight"
+            ]
+            new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[f"down.{i}.downsample.op.bias"]
             # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
             # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']

-        if any('block' in layer for layer in down_blocks[i]):
-            num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'block' in layer})
-            blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+        if any("block" in layer for layer in down_blocks[i]):
+            num_blocks = len(
+                {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "block" in layer}
+            )
+            blocks = {
+                layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
+                for layer_id in range(num_blocks)
+            }

             if num_blocks > 0:
-                for j in range(config['layers_per_block']):
+                for j in range(config["layers_per_block"]):
                     paths = renew_resnet_paths(blocks[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint)

-        if any('attn' in layer for layer in down_blocks[i]):
-            num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in down_blocks[i] if 'attn' in layer})
-            attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+        if any("attn" in layer for layer in down_blocks[i]):
+            num_attn = len(
+                {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "attn" in layer}
+            )
+            attns = {
+                layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
+            }

             if num_attn > 0:
-                for j in range(config['layers_per_block']):
+                for j in range(config["layers_per_block"]):
                     paths = renew_attention_paths(attns[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
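The counting comprehensions in this hunk derive the block layout purely from key names. Running the same expressions on a toy checkpoint shows what they compute:

checkpoint = {
    "down.0.block.0.conv1.weight": 0,
    "down.0.block.1.conv1.weight": 0,
    "down.1.block.0.conv1.weight": 0,
}
num_down_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "down" in layer})
down_blocks = {
    layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
assert num_down_blocks == 2  # the distinct prefixes are {"down.0", "down.1"}
assert down_blocks[1] == ["down.1.block.0.conv1.weight"]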
@@ -150,48 +168,67 @@ def convert_ddpm_checkpoint(checkpoint, config):
     # Mid new 2
     paths = renew_resnet_paths(mid_block_1_layers)
-    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
-        {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'}
-    ])
+    assign_to_checkpoint(
+        paths,
+        new_checkpoint,
+        checkpoint,
+        additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
+    )

     paths = renew_resnet_paths(mid_block_2_layers)
-    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
-        {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_2', 'new': 'resnets.1'}
-    ])
+    assign_to_checkpoint(
+        paths,
+        new_checkpoint,
+        checkpoint,
+        additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
+    )

     paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
-    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
-        {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
-    ])
+    assign_to_checkpoint(
+        paths,
+        new_checkpoint,
+        checkpoint,
+        additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
+    )

     for i in range(num_up_blocks):
         block_id = num_up_blocks - 1 - i

-        if any('upsample' in layer for layer in up_blocks[i]):
-            new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'up.{i}.upsample.conv.weight']
-            new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'up.{i}.upsample.conv.bias']
+        if any("upsample" in layer for layer in up_blocks[i]):
+            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
+                f"up.{i}.upsample.conv.weight"
+            ]
+            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[f"up.{i}.upsample.conv.bias"]

-        if any('block' in layer for layer in up_blocks[i]):
-            num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'block' in layer})
-            blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+        if any("block" in layer for layer in up_blocks[i]):
+            num_blocks = len(
+                {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "block" in layer}
+            )
+            blocks = {
+                layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
+            }

             if num_blocks > 0:
-                for j in range(config['layers_per_block'] + 1):
-                    replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
+                for j in range(config["layers_per_block"] + 1):
+                    replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
                     paths = renew_resnet_paths(blocks[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])

-        if any('attn' in layer for layer in up_blocks[i]):
-            num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in up_blocks[i] if 'attn' in layer})
-            attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+        if any("attn" in layer for layer in up_blocks[i]):
+            num_attn = len(
+                {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "attn" in layer}
+            )
+            attns = {
+                layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
+            }

             if num_attn > 0:
-                for j in range(config['layers_per_block'] + 1):
-                    replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
+                for j in range(config["layers_per_block"] + 1):
+                    replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
                     paths = renew_attention_paths(attns[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])

-    new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()}
+    new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
     return new_checkpoint
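The temporary "mid_new_2" prefix appears to exist so the mid-block replacements cannot be caught by the later up_blocks renames; the final comprehension then swaps it out in one pass. On a sample key set:

renamed = {"mid_new_2.resnets.0.conv1.weight": 0, "conv_out.bias": 1}
renamed = {k.replace("mid_new_2", "mid_block"): v for k, v in renamed.items()}
assert sorted(renamed) == ["conv_out.bias", "mid_block.resnets.0.conv1.weight"]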
@@ -201,50 +238,66 @@ def convert_vq_autoenc_checkpoint(checkpoint, config):
     """
     new_checkpoint = {}

-    new_checkpoint['encoder.conv_norm_out.weight'] = checkpoint['encoder.norm_out.weight']
-    new_checkpoint['encoder.conv_norm_out.bias'] = checkpoint['encoder.norm_out.bias']
+    new_checkpoint["encoder.conv_norm_out.weight"] = checkpoint["encoder.norm_out.weight"]
+    new_checkpoint["encoder.conv_norm_out.bias"] = checkpoint["encoder.norm_out.bias"]

-    new_checkpoint['encoder.conv_in.weight'] = checkpoint['encoder.conv_in.weight']
-    new_checkpoint['encoder.conv_in.bias'] = checkpoint['encoder.conv_in.bias']
-    new_checkpoint['encoder.conv_out.weight'] = checkpoint['encoder.conv_out.weight']
-    new_checkpoint['encoder.conv_out.bias'] = checkpoint['encoder.conv_out.bias']
+    new_checkpoint["encoder.conv_in.weight"] = checkpoint["encoder.conv_in.weight"]
+    new_checkpoint["encoder.conv_in.bias"] = checkpoint["encoder.conv_in.bias"]
+    new_checkpoint["encoder.conv_out.weight"] = checkpoint["encoder.conv_out.weight"]
+    new_checkpoint["encoder.conv_out.bias"] = checkpoint["encoder.conv_out.bias"]

-    new_checkpoint['decoder.conv_norm_out.weight'] = checkpoint['decoder.norm_out.weight']
-    new_checkpoint['decoder.conv_norm_out.bias'] = checkpoint['decoder.norm_out.bias']
+    new_checkpoint["decoder.conv_norm_out.weight"] = checkpoint["decoder.norm_out.weight"]
+    new_checkpoint["decoder.conv_norm_out.bias"] = checkpoint["decoder.norm_out.bias"]

-    new_checkpoint['decoder.conv_in.weight'] = checkpoint['decoder.conv_in.weight']
-    new_checkpoint['decoder.conv_in.bias'] = checkpoint['decoder.conv_in.bias']
-    new_checkpoint['decoder.conv_out.weight'] = checkpoint['decoder.conv_out.weight']
-    new_checkpoint['decoder.conv_out.bias'] = checkpoint['decoder.conv_out.bias']
+    new_checkpoint["decoder.conv_in.weight"] = checkpoint["decoder.conv_in.weight"]
+    new_checkpoint["decoder.conv_in.bias"] = checkpoint["decoder.conv_in.bias"]
+    new_checkpoint["decoder.conv_out.weight"] = checkpoint["decoder.conv_out.weight"]
+    new_checkpoint["decoder.conv_out.bias"] = checkpoint["decoder.conv_out.bias"]

-    num_down_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'down' in layer})
-    down_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}
+    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "down" in layer})
+    down_blocks = {
+        layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+    }

-    num_up_blocks = len({'.'.join(layer.split('.')[:3]) for layer in checkpoint if 'up' in layer})
-    up_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}
+    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "up" in layer})
+    up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}

     for i in range(num_down_blocks):
-        block_id = (i - 1) // (config['layers_per_block'] + 1)
+        block_id = (i - 1) // (config["layers_per_block"] + 1)

-        if any('downsample' in layer for layer in down_blocks[i]):
-            new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'encoder.down.{i}.downsample.conv.weight']
-            new_checkpoint[f'encoder.down_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'encoder.down.{i}.downsample.conv.bias']
+        if any("downsample" in layer for layer in down_blocks[i]):
+            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
+                f"encoder.down.{i}.downsample.conv.weight"
+            ]
+            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[
+                f"encoder.down.{i}.downsample.conv.bias"
+            ]

-        if any('block' in layer for layer in down_blocks[i]):
-            num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'block' in layer})
-            blocks = {layer_id: [key for key in down_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+        if any("block" in layer for layer in down_blocks[i]):
+            num_blocks = len(
+                {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "block" in layer}
+            )
+            blocks = {
+                layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
+                for layer_id in range(num_blocks)
+            }

             if num_blocks > 0:
-                for j in range(config['layers_per_block']):
+                for j in range(config["layers_per_block"]):
                     paths = renew_resnet_paths(blocks[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint)

-        if any('attn' in layer for layer in down_blocks[i]):
-            num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in down_blocks[i] if 'attn' in layer})
-            attns = {layer_id: [key for key in down_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+        if any("attn" in layer for layer in down_blocks[i]):
+            num_attn = len(
+                {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "attn" in layer}
+            )
+            attns = {
+                layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
+            }

             if num_attn > 0:
-                for j in range(config['layers_per_block']):
+                for j in range(config["layers_per_block"]):
                     paths = renew_attention_paths(attns[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
@@ -254,48 +307,69 @@ def convert_vq_autoenc_checkpoint(checkpoint, config):
     # Mid new 2
     paths = renew_resnet_paths(mid_block_1_layers)
-    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
-        {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'}
-    ])
+    assign_to_checkpoint(
+        paths,
+        new_checkpoint,
+        checkpoint,
+        additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
+    )

     paths = renew_resnet_paths(mid_block_2_layers)
-    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
-        {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_2', 'new': 'resnets.1'}
-    ])
+    assign_to_checkpoint(
+        paths,
+        new_checkpoint,
+        checkpoint,
+        additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
+    )

     paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
-    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
-        {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
-    ])
+    assign_to_checkpoint(
+        paths,
+        new_checkpoint,
+        checkpoint,
+        additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
+    )

     for i in range(num_up_blocks):
         block_id = num_up_blocks - 1 - i

-        if any('upsample' in layer for layer in up_blocks[i]):
-            new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'decoder.up.{i}.upsample.conv.weight']
-            new_checkpoint[f'decoder.up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'decoder.up.{i}.upsample.conv.bias']
+        if any("upsample" in layer for layer in up_blocks[i]):
+            new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
+                f"decoder.up.{i}.upsample.conv.weight"
+            ]
+            new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
+                f"decoder.up.{i}.upsample.conv.bias"
+            ]

-        if any('block' in layer for layer in up_blocks[i]):
-            num_blocks = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'block' in layer})
-            blocks = {layer_id: [key for key in up_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}
+        if any("block" in layer for layer in up_blocks[i]):
+            num_blocks = len(
+                {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "block" in layer}
+            )
+            blocks = {
+                layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
+            }

             if num_blocks > 0:
-                for j in range(config['layers_per_block'] + 1):
-                    replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
+                for j in range(config["layers_per_block"] + 1):
+                    replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
                     paths = renew_resnet_paths(blocks[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])

-        if any('attn' in layer for layer in up_blocks[i]):
-            num_attn = len({'.'.join(shave_segments(layer, 3).split('.')[:3]) for layer in up_blocks[i] if 'attn' in layer})
-            attns = {layer_id: [key for key in up_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
+        if any("attn" in layer for layer in up_blocks[i]):
+            num_attn = len(
+                {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "attn" in layer}
+            )
+            attns = {
+                layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
+            }

             if num_attn > 0:
-                for j in range(config['layers_per_block'] + 1):
-                    replace_indices = {'old': f'up_blocks.{i}', 'new': f'up_blocks.{block_id}'}
+                for j in range(config["layers_per_block"] + 1):
+                    replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
                     paths = renew_attention_paths(attns[j])
                     assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])

-    new_checkpoint = {k.replace('mid_new_2', 'mid_block'): v for k, v in new_checkpoint.items()}
+    new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
     new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"]
     new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"]
     if "quantize.embedding.weight" in checkpoint:
@@ -321,9 +395,7 @@ if __name__ == "__main__":
         help="The config json file corresponding to the architecture.",
     )
-    parser.add_argument(
-        "--dump_path", default=None, type=str, required=True, help="Path to the output model."
-    )
+    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

     args = parser.parse_args()

     checkpoint = torch.load(args.checkpoint_path)
scripts/convert_ldm_original_checkpoint_to_diffusers.py

@@ -16,8 +16,10 @@
 import argparse
 import json

 import torch

-from diffusers import VQModel, DDPMScheduler, UNet2DModel, LDMPipeline
+from diffusers import DDPMScheduler, LDMPipeline, UNet2DModel, VQModel


 def shave_segments(path, n_shave_prefix_segments=1):

@@ -25,9 +27,9 @@ def shave_segments(path, n_shave_prefix_segments=1):
     Removes segments. Positive values shave the first segments, negative shave the last segments.
     """
     if n_shave_prefix_segments >= 0:
-        return '.'.join(path.split('.')[n_shave_prefix_segments:])
+        return ".".join(path.split(".")[n_shave_prefix_segments:])
     else:
-        return '.'.join(path.split('.')[:n_shave_prefix_segments])
+        return ".".join(path.split(".")[:n_shave_prefix_segments])


 def renew_resnet_paths(old_list, n_shave_prefix_segments=0):

@@ -36,18 +38,18 @@ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
     """
     mapping = []
     for old_item in old_list:
-        new_item = old_item.replace('in_layers.0', 'norm1')
-        new_item = new_item.replace('in_layers.2', 'conv1')
+        new_item = old_item.replace("in_layers.0", "norm1")
+        new_item = new_item.replace("in_layers.2", "conv1")

-        new_item = new_item.replace('out_layers.0', 'norm2')
-        new_item = new_item.replace('out_layers.3', 'conv2')
+        new_item = new_item.replace("out_layers.0", "norm2")
+        new_item = new_item.replace("out_layers.3", "conv2")

-        new_item = new_item.replace('emb_layers.1', 'time_emb_proj')
-        new_item = new_item.replace('skip_connection', 'conv_shortcut')
+        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
+        new_item = new_item.replace("skip_connection", "conv_shortcut")

         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

-        mapping.append({'old': old_item, 'new': new_item})
+        mapping.append({"old": old_item, "new": new_item})

     return mapping
@@ -60,20 +62,22 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
     for old_item in old_list:
         new_item = old_item

-        new_item = new_item.replace('norm.weight', 'group_norm.weight')
-        new_item = new_item.replace('norm.bias', 'group_norm.bias')
+        new_item = new_item.replace("norm.weight", "group_norm.weight")
+        new_item = new_item.replace("norm.bias", "group_norm.bias")

-        new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
-        new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
+        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
+        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")

         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

-        mapping.append({'old': old_item, 'new': new_item})
+        mapping.append({"old": old_item, "new": new_item})

     return mapping
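This LDM variant of renew_attention_paths only renames the norm and the output projection. Assuming it is in scope, its behavior on two representative keys:

mapping = renew_attention_paths(["middle_block.1.norm.weight", "middle_block.1.proj_out.bias"])
assert mapping == [
    {"old": "middle_block.1.norm.weight", "new": "middle_block.1.group_norm.weight"},
    {"old": "middle_block.1.proj_out.bias", "new": "middle_block.1.proj_attn.bias"},
]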
 def assign_to_checkpoint(
     paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
 ):
     """
     This does the final conversion step: take locally converted weights and apply a global renaming
     to them. It splits attention layers, and takes into account additional replacements

@@ -96,31 +100,31 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
             old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
             query, key, value = old_tensor.split(channels // num_heads, dim=1)

-            checkpoint[path_map['query']] = query.reshape(target_shape)
-            checkpoint[path_map['key']] = key.reshape(target_shape)
-            checkpoint[path_map['value']] = value.reshape(target_shape)
+            checkpoint[path_map["query"]] = query.reshape(target_shape)
+            checkpoint[path_map["key"]] = key.reshape(target_shape)
+            checkpoint[path_map["value"]] = value.reshape(target_shape)

     for path in paths:
-        new_path = path['new']
+        new_path = path["new"]

         # These have already been assigned
         if attention_paths_to_split is not None and new_path in attention_paths_to_split:
             continue

         # Global renaming happens here
-        new_path = new_path.replace('middle_block.0', 'mid.resnets.0')
-        new_path = new_path.replace('middle_block.1', 'mid.attentions.0')
-        new_path = new_path.replace('middle_block.2', 'mid.resnets.1')
+        new_path = new_path.replace("middle_block.0", "mid.resnets.0")
+        new_path = new_path.replace("middle_block.1", "mid.attentions.0")
+        new_path = new_path.replace("middle_block.2", "mid.resnets.1")

         if additional_replacements is not None:
             for replacement in additional_replacements:
-                new_path = new_path.replace(replacement['old'], replacement['new'])
+                new_path = new_path.replace(replacement["old"], replacement["new"])

         # proj_attn.weight has to be converted from conv 1D to linear
         if "proj_attn.weight" in new_path:
-            checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0]
+            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
         else:
-            checkpoint[new_path] = old_checkpoint[path['old']]
+            checkpoint[new_path] = old_checkpoint[path["old"]]
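The conv-1D-to-linear branch above relies on the fact that a 1x1 Conv1d weight of shape (out, in, 1) carries the same data as a Linear weight of shape (out, in), so slicing off the trailing kernel axis is the whole conversion. Toy shapes:

import torch

conv1d_weight = torch.randn(8, 8, 1)    # (out_channels, in_channels, kernel_size=1)
linear_weight = conv1d_weight[:, :, 0]  # (out_features, in_features)
assert linear_weight.shape == (8, 8)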
 def convert_ldm_checkpoint(checkpoint, config):

@@ -129,60 +133,78 @@ def convert_ldm_checkpoint(checkpoint, config):
     """
     new_checkpoint = {}

-    new_checkpoint['time_embedding.linear_1.weight'] = checkpoint['time_embed.0.weight']
-    new_checkpoint['time_embedding.linear_1.bias'] = checkpoint['time_embed.0.bias']
-    new_checkpoint['time_embedding.linear_2.weight'] = checkpoint['time_embed.2.weight']
-    new_checkpoint['time_embedding.linear_2.bias'] = checkpoint['time_embed.2.bias']
+    new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"]
+    new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"]
+    new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"]
+    new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"]

-    new_checkpoint['conv_in.weight'] = checkpoint['input_blocks.0.0.weight']
-    new_checkpoint['conv_in.bias'] = checkpoint['input_blocks.0.0.bias']
+    new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"]
+    new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"]

-    new_checkpoint['conv_norm_out.weight'] = checkpoint['out.0.weight']
-    new_checkpoint['conv_norm_out.bias'] = checkpoint['out.0.bias']
-    new_checkpoint['conv_out.weight'] = checkpoint['out.2.weight']
-    new_checkpoint['conv_out.bias'] = checkpoint['out.2.bias']
+    new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
+    new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"]
+    new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"]
+    new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"]

     # Retrieves the keys for the input blocks only
-    num_input_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'input_blocks' in layer})
-    input_blocks = {layer_id: [key for key in checkpoint if f'input_blocks.{layer_id}' in key] for layer_id in range(num_input_blocks)}
+    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "input_blocks" in layer})
+    input_blocks = {
+        layer_id: [key for key in checkpoint if f"input_blocks.{layer_id}" in key]
+        for layer_id in range(num_input_blocks)
+    }

     # Retrieves the keys for the middle blocks only
-    num_middle_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'middle_block' in layer})
-    middle_blocks = {layer_id: [key for key in checkpoint if f'middle_block.{layer_id}' in key] for layer_id in range(num_middle_blocks)}
+    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "middle_block" in layer})
+    middle_blocks = {
+        layer_id: [key for key in checkpoint if f"middle_block.{layer_id}" in key]
+        for layer_id in range(num_middle_blocks)
+    }

     # Retrieves the keys for the output blocks only
-    num_output_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'output_blocks' in layer})
-    output_blocks = {layer_id: [key for key in checkpoint if f'output_blocks.{layer_id}' in key] for layer_id in range(num_output_blocks)}
+    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "output_blocks" in layer})
+    output_blocks = {
+        layer_id: [key for key in checkpoint if f"output_blocks.{layer_id}" in key]
+        for layer_id in range(num_output_blocks)
+    }

     for i in range(1, num_input_blocks):
-        block_id = (i - 1) // (config['num_res_blocks'] + 1)
-        layer_in_block_id = (i - 1) % (config['num_res_blocks'] + 1)
+        block_id = (i - 1) // (config["num_res_blocks"] + 1)
+        layer_in_block_id = (i - 1) % (config["num_res_blocks"] + 1)

-        resnets = [key for key in input_blocks[i] if f'input_blocks.{i}.0' in key]
-        attentions = [key for key in input_blocks[i] if f'input_blocks.{i}.1' in key]
+        resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key]
+        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

-        if f'input_blocks.{i}.0.op.weight' in checkpoint:
-            new_checkpoint[f'downsample_blocks.{block_id}.downsamplers.0.conv.weight'] = checkpoint[f'input_blocks.{i}.0.op.weight']
-            new_checkpoint[f'downsample_blocks.{block_id}.downsamplers.0.conv.bias'] = checkpoint[f'input_blocks.{i}.0.op.bias']
+        if f"input_blocks.{i}.0.op.weight" in checkpoint:
+            new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[
+                f"input_blocks.{i}.0.op.weight"
+            ]
+            new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[
+                f"input_blocks.{i}.0.op.bias"
+            ]

         paths = renew_resnet_paths(resnets)
-        meta_path = {'old': f'input_blocks.{i}.0', 'new': f'downsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
-        resnet_op = {'old': 'resnets.2.op', 'new': 'downsamplers.0.op'}
+        meta_path = {"old": f"input_blocks.{i}.0", "new": f"downsample_blocks.{block_id}.resnets.{layer_in_block_id}"}
+        resnet_op = {"old": "resnets.2.op", "new": "downsamplers.0.op"}
         assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config)

         if len(attentions):
             paths = renew_attention_paths(attentions)
-            meta_path = {'old': f'input_blocks.{i}.1', 'new': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}'}
+            meta_path = {
+                "old": f"input_blocks.{i}.1",
+                "new": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}",
+            }
             to_split = {
-                f'input_blocks.{i}.1.qkv.bias': {
-                    'key': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
-                    'query': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
-                    'value': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
+                f"input_blocks.{i}.1.qkv.bias": {
+                    "key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
+                    "query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
+                    "value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
                 },
-                f'input_blocks.{i}.1.qkv.weight': {
-                    'key': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
-                    'query': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
-                    'value': f'downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
+                f"input_blocks.{i}.1.qkv.weight": {
+                    "key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
+                    "query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
+                    "value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
                 },
             }
             assign_to_checkpoint(

@@ -191,7 +213,7 @@ def convert_ldm_checkpoint(checkpoint, config):
                 checkpoint,
                 additional_replacements=[meta_path],
                 attention_paths_to_split=to_split,
-                config=config
+                config=config,
             )

     resnet_0 = middle_blocks[0]
@@ -206,46 +228,52 @@ def convert_ldm_checkpoint(checkpoint, config):
     attentions_paths = renew_attention_paths(attentions)
     to_split = {
-        'middle_block.1.qkv.bias': {
-            'key': 'mid_block.attentions.0.key.bias',
-            'query': 'mid_block.attentions.0.query.bias',
-            'value': 'mid_block.attentions.0.value.bias',
+        "middle_block.1.qkv.bias": {
+            "key": "mid_block.attentions.0.key.bias",
+            "query": "mid_block.attentions.0.query.bias",
+            "value": "mid_block.attentions.0.value.bias",
         },
-        'middle_block.1.qkv.weight': {
-            'key': 'mid_block.attentions.0.key.weight',
-            'query': 'mid_block.attentions.0.query.weight',
-            'value': 'mid_block.attentions.0.value.weight',
+        "middle_block.1.qkv.weight": {
+            "key": "mid_block.attentions.0.key.weight",
+            "query": "mid_block.attentions.0.query.weight",
+            "value": "mid_block.attentions.0.value.weight",
         },
     }
     assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config)

     for i in range(num_output_blocks):
-        block_id = i // (config['num_res_blocks'] + 1)
-        layer_in_block_id = i % (config['num_res_blocks'] + 1)
+        block_id = i // (config["num_res_blocks"] + 1)
+        layer_in_block_id = i % (config["num_res_blocks"] + 1)
         output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
         output_block_list = {}

         for layer in output_block_layers:
-            layer_id, layer_name = layer.split('.')[0], shave_segments(layer, 1)
+            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
             if layer_id in output_block_list:
                 output_block_list[layer_id].append(layer_name)
             else:
                 output_block_list[layer_id] = [layer_name]

         if len(output_block_list) > 1:
-            resnets = [key for key in output_blocks[i] if f'output_blocks.{i}.0' in key]
-            attentions = [key for key in output_blocks[i] if f'output_blocks.{i}.1' in key]
+            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
+            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

             resnet_0_paths = renew_resnet_paths(resnets)
             paths = renew_resnet_paths(resnets)

-            meta_path = {'old': f'output_blocks.{i}.0', 'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}'}
+            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
             assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)

-            if ['conv.weight', 'conv.bias'] in output_block_list.values():
-                index = list(output_block_list.values()).index(['conv.weight', 'conv.bias'])
-                new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'output_blocks.{i}.{index}.conv.weight']
-                new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'output_blocks.{i}.{index}.conv.bias']
+            if ["conv.weight", "conv.bias"] in output_block_list.values():
+                index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
+                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
+                    f"output_blocks.{i}.{index}.conv.weight"
+                ]
+                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
+                    f"output_blocks.{i}.{index}.conv.bias"
+                ]

                 # Clear attentions as they have been attributed above.
                 if len(attentions) == 2:

@@ -254,19 +282,19 @@ def convert_ldm_checkpoint(checkpoint, config):
             if len(attentions):
                 paths = renew_attention_paths(attentions)
-                meta_path = {
-                    'old': f'output_blocks.{i}.1',
-                    'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}'
-                }
+                meta_path = {
+                    "old": f"output_blocks.{i}.1",
+                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
+                }
                 to_split = {
-                    f'output_blocks.{i}.1.qkv.bias': {
-                        'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias',
-                        'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias',
-                        'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias',
+                    f"output_blocks.{i}.1.qkv.bias": {
+                        "key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
+                        "query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
+                        "value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
                     },
-                    f'output_blocks.{i}.1.qkv.weight': {
-                        'key': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight',
-                        'query': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight',
-                        'value': f'up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight',
+                    f"output_blocks.{i}.1.qkv.weight": {
+                        "key": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
+                        "query": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
+                        "value": f"up_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
                     },
                 }
                 assign_to_checkpoint(
...
@@ -274,14 +302,14 @@ def convert_ldm_checkpoint(checkpoint, config):
                    paths,
                    new_checkpoint,
                    checkpoint,
                    additional_replacements=[meta_path],
                    attention_paths_to_split=to_split if any("qkv" in key for key in attentions) else None,
                    config=config,
                )
            else:
                resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
                for path in resnet_0_paths:
                    old_path = ".".join(["output_blocks", str(i), path["old"]])
                    new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                    new_checkpoint[new_path] = checkpoint[old_path]
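The `attention_paths_to_split` argument is what handles checkpoints that store attention as one fused `qkv` projection: `assign_to_checkpoint` slices that tensor into the separate query/key/value entries named in `to_split`. A minimal sketch of such a split, assuming the fused tensor stacks the three projections along dim 0 (names and shapes are illustrative only):

import torch

channels = 64  # illustrative width
qkv_weight = torch.randn(3 * channels, channels)  # hypothetical fused projection
query_w, key_w, value_w = qkv_weight.chunk(3, dim=0)  # slice into q, k, v
assert query_w.shape == key_w.shape == value_w.shape == (channels, channels)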
...
@@ -303,9 +331,7 @@ if __name__ == "__main__":
        help="The config json file corresponding to the architecture.",
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    args = parser.parse_args()
...
scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py
View file @ 89793a97
...
@@ -16,8 +16,10 @@
import argparse
import json

import torch

from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel


def convert_ncsnpp_checkpoint(checkpoint, config):
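The widened import line is the visible part of this hunk: the converter can now wrap the restored UNet in a score-SDE pipeline. A minimal sketch of how the newly imported classes typically fit together (assumed usage with illustrative config values and a hypothetical output path, not shown in this diff):

from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel

model = UNet2DModel(sample_size=64, in_channels=3, out_channels=3)  # illustrative config
scheduler = ScoreSdeVeScheduler()  # default VE-SDE schedule, assumed
sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
sde_ve.save_pretrained("./ncsnpp-converted")  # hypothetical output directory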
...
scripts/generate_logits.py
View file @ 89793a97
import random

import torch

from diffusers import UNet2DModel
from huggingface_hub import HfApi


api = HfApi()
results = {}
# fmt: off
results["google_ddpm_cifar10_32"] = torch.tensor([
    -0.7515, -1.6883, 0.2420, 0.0300, 0.6347, 1.3433, -1.1743, -3.7467,
    1.2342, -2.2485, 0.4636, 0.8076, -0.7991, 0.3969, 0.8498, 0.9189,
    -1.8887, -3.3522, 0.7639, 0.2040, 0.6271, -2.7148, -1.6316, 3.0839,
    0.3186, 0.2721, -0.9759, -1.2461, 2.6257, 1.3557
])
results["google_ddpm_ema_bedroom_256"] = torch.tensor([
    -2.3639, -2.5344, 0.0054, -0.6674, 1.5990, 1.0158, 0.3124, -2.1436,
    1.8795, -2.5429, -0.1566, -0.3973, 1.2490, 2.6447, 1.2283, -0.5208,
    -2.8154, -3.5119, 2.3838, 1.2033, 1.7201, -2.1256, -1.4576, 2.7948,
    2.4204, -0.9752, -1.2546, 0.8027, 3.2758, 3.1365
])
results["CompVis_ldm_celebahq_256"] = torch.tensor([
    -0.6531, -0.6891, -0.3172, -0.5375, -0.9140, -0.5367, -0.1175, -0.7869,
    -0.3808, -0.4513, -0.2098, -0.0083, 0.3183, 0.5140, 0.2247, -0.1304,
    -0.1302, -0.2802, -0.2084, -0.2025, -0.4967, -0.4873, -0.0861, 0.6925,
    0.0250, 0.1290, -0.1543, 0.6316, 1.0460, 1.4943
])
results["google_ncsnpp_ffhq_1024"] = torch.tensor([
    0.0911, 0.1107, 0.0182, 0.0435, -0.0805, -0.0608, 0.0381, 0.2172,
    -0.0280, 0.1327, -0.0299, -0.0255, -0.0050, -0.1170, -0.1046, 0.0309,
    0.1367, 0.1728, -0.0533, -0.0748, -0.0534, 0.1624, 0.0384, -0.1805,
    -0.0707, 0.0642, 0.0220, -0.0134, -0.1333, -0.1505
])
results["google_ncsnpp_bedroom_256"] = torch.tensor([
    0.1321, 0.1337, 0.0440, 0.0622, -0.0591, -0.0370, 0.0503, 0.2133,
    -0.0177, 0.1415, -0.0116, -0.0112, 0.0044, -0.0980, -0.0789, 0.0395,
    0.1502, 0.1785, -0.0488, -0.0514, -0.0404, 0.1539, 0.0454, -0.1559,
    -0.0665, 0.0659, 0.0383, -0.0005, -0.1266, -0.1386
])
results["google_ncsnpp_celebahq_256"] = torch.tensor([
    0.1154, 0.1218, 0.0307, 0.0526, -0.0711, -0.0541, 0.0366, 0.2078,
    -0.0267, 0.1317, -0.0226, -0.0193, -0.0014, -0.1055, -0.0902, 0.0330,
    0.1391, 0.1709, -0.0562, -0.0693, -0.0560, 0.1482, 0.0381, -0.1683,
    -0.0681, 0.0661, 0.0331, -0.0046, -0.1268, -0.1431
])
results["google_ncsnpp_church_256"] = torch.tensor([
    0.1192, 0.1240, 0.0414, 0.0606, -0.0557, -0.0412, 0.0430, 0.2042,
    -0.0200, 0.1385, -0.0115, -0.0132, 0.0017, -0.0965, -0.0802, 0.0398,
    0.1433, 0.1747, -0.0458, -0.0533, -0.0407, 0.1545, 0.0419, -0.1574,
    -0.0645, 0.0626, 0.0341, -0.0010, -0.1199, -0.1390
])
results["google_ncsnpp_ffhq_256"] = torch.tensor([
    0.1075, 0.1074, 0.0205, 0.0431, -0.0774, -0.0607, 0.0298, 0.2042,
    -0.0320, 0.1267, -0.0281, -0.0250, -0.0064, -0.1091, -0.0946, 0.0290,
    0.1328, 0.1650, -0.0580, -0.0738, -0.0586, 0.1440, 0.0337, -0.1746,
    -0.0712, 0.0605, 0.0250, -0.0099, -0.1316, -0.1473
])
results["google_ddpm_cat_256"] = torch.tensor([
    -1.4572, -2.0481, -0.0414, -0.6005, 1.4136, 0.5848, 0.4028, -2.7330,
    1.2212, -2.1228, 0.2155, 0.4039, 0.7662, 2.0535, 0.7477, -0.3243,
    -2.1758, -2.7648, 1.6947, 0.7026, 1.2338, -1.6078, -0.8682, 2.2810,
    1.8574, -0.5718, -0.5586, -0.0186, 2.3415, 2.1251
])
results["google_ddpm_celebahq_256"] = torch.tensor([
    -1.3690, -1.9720, -0.4090, -0.6966, 1.4660, 0.9938, -0.1385, -2.7324,
    0.7736, -1.8917, 0.2923, 0.4293, 0.1693, 1.4112, 1.1887, -0.3181,
    -2.2160, -2.6381, 1.3170, 0.8163, 0.9240, -1.6544, -0.6099, 2.5259,
    1.6430, -0.9090, -0.9392, -0.0126, 2.4268, 2.3266
])
results["google_ddpm_ema_celebahq_256"] = torch.tensor([
    -1.3525, -1.9628, -0.3956, -0.6860, 1.4664, 1.0014, -0.1259, -2.7212,
    0.7772, -1.8811, 0.2996, 0.4388, 0.1704, 1.4029, 1.1701, -0.3027,
    -2.2053, -2.6287, 1.3350, 0.8131, 0.9274, -1.6292, -0.6098, 2.5131,
    1.6505, -0.8958, -0.9298, -0.0151, 2.4257, 2.3355
])
results["google_ddpm_church_256"] = torch.tensor([
    -2.0585, -2.7897, -0.2850, -0.8940, 1.9052, 0.5702, 0.6345, -3.8959,
    1.5932, -3.2319, 0.1974, 0.0287, 1.7566, 2.6543, 0.8387, -0.5351,
    -3.2736, -4.3375, 2.9029, 1.6390, 1.4640, -2.1701, -1.9013, 2.9341,
    3.4981, -0.6255, -1.1644, -0.1591, 3.7097, 3.2066
])
results["google_ddpm_bedroom_256"] = torch.tensor([
    -2.3139, -2.5594, -0.0197, -0.6785, 1.7001, 1.1606, 0.3075, -2.1740,
    1.8071, -2.5630, -0.0926, -0.3811, 1.2116, 2.6246, 1.2731, -0.5398,
    -2.8153, -3.6140, 2.3893, 1.3262, 1.6258, -2.1856, -1.3267, 2.8395,
    2.3779, -1.0623, -1.2468, 0.8959, 3.3367, 3.2243
])
results["google_ddpm_ema_church_256"] = torch.tensor([
    -2.0628, -2.7667, -0.2089, -0.8263, 2.0539, 0.5992, 0.6495, -3.8336,
    1.6025, -3.2817, 0.1721, -0.0633, 1.7516, 2.7039, 0.8100, -0.5908,
    -3.2113, -4.4343, 2.9257, 1.3632, 1.5562, -2.1489, -1.9894, 3.0560,
    3.3396, -0.7328, -1.0417, 0.0383, 3.7093, 3.2343
])
results["google_ddpm_ema_cat_256"] = torch.tensor([
    -1.4574, -2.0569, -0.0473, -0.6117, 1.4018, 0.5769, 0.4129, -2.7344,
    1.2241, -2.1397, 0.2000, 0.3937, 0.7616, 2.0453, 0.7324, -0.3391,
    -2.1746, -2.7744, 1.6963, 0.6921, 1.2187, -1.6172, -0.8877, 2.2439,
    1.8471, -0.5839, -0.5605, -0.0464, 2.3250, 2.1219
])
# fmt: on
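The `# fmt: off` / `# fmt: on` pair added here tells black to leave the block in between untouched, so the hand-aligned reference tensors keep their eight-values-per-row layout on future formatting runs.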
models = api.list_models(filter="diffusers")
for mod in models:
    if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256":
        local_checkpoint = "/home/patrick/google_checkpoints/" + mod.modelId.split("/")[-1]

        print(f"Started running {mod.modelId}!!!")

        if mod.modelId.startswith("CompVis"):
            model = UNet2DModel.from_pretrained(local_checkpoint, subfolder="unet")
        else:
            model = UNet2DModel.from_pretrained(local_checkpoint)

        torch.manual_seed(0)
        random.seed(0)

        noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
        time_step = torch.tensor([10] * noise.shape[0])
        with torch.no_grad():
            logits = model(noise, time_step)["sample"]

        assert torch.allclose(
            logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3
        )
        print(f"{mod.modelId} has passed successfully!!!")
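The doubly nested join in the assert above is easy to misread; here is a worked example of the normalization it applies to a hub id (using one of the ids the loop actually visits):

model_id = "google/ddpm-cat-256"
key = "_".join("_".join(model_id.split("/")).split("-"))
print(key)  # google_ddpm_cat_256 -- the exact spelling used as a results key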