Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
3f0b44b3
Commit
3f0b44b3
authored
Jul 19, 2022
by
Patrick von Platen
Browse files
improve ddpm conversion script
parent
cb90fd69
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
11 deletions
+14
-11
scripts/convert_ddpm_original_checkpoint_to_diffusers.py
scripts/convert_ddpm_original_checkpoint_to_diffusers.py
+14
-11
No files found.
scripts/convert_ddpm_original_checkpoint_to_diffusers.py
View file @
3f0b44b3
from
diffusers
import
UNetUnconditionalModel
from
diffusers
import
UNetUnconditionalModel
,
DDPMScheduler
,
DDPMPipeline
import
argparse
import
json
import
torch
...
...
@@ -56,7 +56,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
if
attention_paths_to_split
is
not
None
:
if
config
is
None
:
raise
ValueError
(
f
"Please specify the config if setting 'attention_paths_to_split' to 'True'."
)
raise
ValueError
(
"Please specify the config if setting 'attention_paths_to_split' to 'True'."
)
for
path
,
path_map
in
attention_paths_to_split
.
items
():
old_tensor
=
old_checkpoint
[
path
]
...
...
@@ -86,7 +86,6 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
for
replacement
in
additional_replacements
:
new_path
=
new_path
.
replace
(
replacement
[
'old'
],
replacement
[
'new'
])
if
'attentions'
in
new_path
:
checkpoint
[
new_path
]
=
old_checkpoint
[
path
[
'old'
]].
squeeze
()
else
:
...
...
@@ -97,7 +96,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
new_checkpoint
=
{}
new_checkpoint
[
'time_embedding.linear_1.weight'
]
=
checkpoint
[
'temb.dense.0.weight'
]
...
...
@@ -121,7 +119,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
for
i
in
range
(
num_downsample_blocks
):
block_id
=
(
i
-
1
)
//
(
config
[
'num_res_blocks'
]
+
1
)
layer_in_block_id
=
(
i
-
1
)
%
(
config
[
'num_res_blocks'
]
+
1
)
if
any
(
'downsample'
in
layer
for
layer
in
downsample_blocks
[
i
]):
new_checkpoint
[
f
'downsample_blocks.
{
i
}
.downsamplers.0.conv.weight'
]
=
checkpoint
[
f
'down.
{
i
}
.downsample.conv.weight'
]
...
...
@@ -138,7 +135,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
paths
=
renew_resnet_paths
(
blocks
[
j
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
)
if
any
(
'attn'
in
layer
for
layer
in
downsample_blocks
[
i
]):
num_attn
=
len
({
'.'
.
join
(
shave_segments
(
layer
,
2
).
split
(
'.'
)[:
2
])
for
layer
in
downsample_blocks
[
i
]
if
'attn'
in
layer
})
attns
=
{
layer_id
:
[
key
for
key
in
downsample_blocks
[
i
]
if
f
'attn.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_blocks
)}
...
...
@@ -148,7 +144,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
paths
=
renew_attention_paths
(
attns
[
j
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
config
=
config
)
mid_block_1_layers
=
[
key
for
key
in
checkpoint
if
"mid.block_1"
in
key
]
mid_block_2_layers
=
[
key
for
key
in
checkpoint
if
"mid.block_2"
in
key
]
mid_attn_1_layers
=
[
key
for
key
in
checkpoint
if
"mid.attn_1"
in
key
]
...
...
@@ -186,7 +181,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
paths
=
renew_resnet_paths
(
blocks
[
j
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
replace_indices
])
if
any
(
'attn'
in
layer
for
layer
in
upsample_blocks
[
i
]):
num_attn
=
len
({
'.'
.
join
(
shave_segments
(
layer
,
2
).
split
(
'.'
)[:
2
])
for
layer
in
upsample_blocks
[
i
]
if
'attn'
in
layer
})
attns
=
{
layer_id
:
[
key
for
key
in
upsample_blocks
[
i
]
if
f
'attn.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_blocks
)}
...
...
@@ -220,12 +214,21 @@ if __name__ == "__main__":
"--dump_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the output model."
)
args
=
parser
.
parse_args
()
checkpoint
=
torch
.
load
(
args
.
checkpoint_path
)
with
open
(
args
.
config_file
)
as
f
:
config
=
json
.
loads
(
f
.
read
())
converted_checkpoint
=
convert_ddpm_checkpoint
(
args
.
checkpoint_path
,
args
.
config_file
)
torch
.
save
(
converted_checkpoint
,
args
.
dump_path
)
converted_checkpoint
=
convert_ddpm_checkpoint
(
checkpoint
,
config
)
if
"ddpm"
in
config
:
del
config
[
"ddpm"
]
model
=
UNetUnconditionalModel
(
**
config
)
model
.
load_state_dict
(
converted_checkpoint
)
scheduler
=
DDPMScheduler
.
from_config
(
"/"
.
join
(
args
.
checkpoint_path
.
split
(
"/"
)[:
-
1
]))
pipe
=
DDPMPipeline
(
unet
=
model
,
scheduler
=
scheduler
)
pipe
.
save_pretrained
(
args
.
dump_path
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment