Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
3f0b44b3
Commit
3f0b44b3
authored
Jul 19, 2022
by
Patrick von Platen
Browse files
improve ddpm conversion script
parent
cb90fd69
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
11 deletions
+14
-11
scripts/convert_ddpm_original_checkpoint_to_diffusers.py
scripts/convert_ddpm_original_checkpoint_to_diffusers.py
+14
-11
No files found.
scripts/convert_ddpm_original_checkpoint_to_diffusers.py
View file @
3f0b44b3
from
diffusers
import
UNetUnconditionalModel
from
diffusers
import
UNetUnconditionalModel
,
DDPMScheduler
,
DDPMPipeline
import
argparse
import
argparse
import
json
import
json
import
torch
import
torch
...
@@ -56,7 +56,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
...
@@ -56,7 +56,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
if
attention_paths_to_split
is
not
None
:
if
attention_paths_to_split
is
not
None
:
if
config
is
None
:
if
config
is
None
:
raise
ValueError
(
f
"Please specify the config if setting 'attention_paths_to_split' to 'True'."
)
raise
ValueError
(
"Please specify the config if setting 'attention_paths_to_split' to 'True'."
)
for
path
,
path_map
in
attention_paths_to_split
.
items
():
for
path
,
path_map
in
attention_paths_to_split
.
items
():
old_tensor
=
old_checkpoint
[
path
]
old_tensor
=
old_checkpoint
[
path
]
...
@@ -86,7 +86,6 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
...
@@ -86,7 +86,6 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
for
replacement
in
additional_replacements
:
for
replacement
in
additional_replacements
:
new_path
=
new_path
.
replace
(
replacement
[
'old'
],
replacement
[
'new'
])
new_path
=
new_path
.
replace
(
replacement
[
'old'
],
replacement
[
'new'
])
if
'attentions'
in
new_path
:
if
'attentions'
in
new_path
:
checkpoint
[
new_path
]
=
old_checkpoint
[
path
[
'old'
]].
squeeze
()
checkpoint
[
new_path
]
=
old_checkpoint
[
path
[
'old'
]].
squeeze
()
else
:
else
:
...
@@ -97,7 +96,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
...
@@ -97,7 +96,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
"""
"""
Takes a state dict and a config, and returns a converted checkpoint.
Takes a state dict and a config, and returns a converted checkpoint.
"""
"""
new_checkpoint
=
{}
new_checkpoint
=
{}
new_checkpoint
[
'time_embedding.linear_1.weight'
]
=
checkpoint
[
'temb.dense.0.weight'
]
new_checkpoint
[
'time_embedding.linear_1.weight'
]
=
checkpoint
[
'temb.dense.0.weight'
]
...
@@ -121,7 +119,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
...
@@ -121,7 +119,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
for
i
in
range
(
num_downsample_blocks
):
for
i
in
range
(
num_downsample_blocks
):
block_id
=
(
i
-
1
)
//
(
config
[
'num_res_blocks'
]
+
1
)
block_id
=
(
i
-
1
)
//
(
config
[
'num_res_blocks'
]
+
1
)
layer_in_block_id
=
(
i
-
1
)
%
(
config
[
'num_res_blocks'
]
+
1
)
if
any
(
'downsample'
in
layer
for
layer
in
downsample_blocks
[
i
]):
if
any
(
'downsample'
in
layer
for
layer
in
downsample_blocks
[
i
]):
new_checkpoint
[
f
'downsample_blocks.
{
i
}
.downsamplers.0.conv.weight'
]
=
checkpoint
[
f
'down.
{
i
}
.downsample.conv.weight'
]
new_checkpoint
[
f
'downsample_blocks.
{
i
}
.downsamplers.0.conv.weight'
]
=
checkpoint
[
f
'down.
{
i
}
.downsample.conv.weight'
]
...
@@ -138,7 +135,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
...
@@ -138,7 +135,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
paths
=
renew_resnet_paths
(
blocks
[
j
])
paths
=
renew_resnet_paths
(
blocks
[
j
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
)
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
)
if
any
(
'attn'
in
layer
for
layer
in
downsample_blocks
[
i
]):
if
any
(
'attn'
in
layer
for
layer
in
downsample_blocks
[
i
]):
num_attn
=
len
({
'.'
.
join
(
shave_segments
(
layer
,
2
).
split
(
'.'
)[:
2
])
for
layer
in
downsample_blocks
[
i
]
if
'attn'
in
layer
})
num_attn
=
len
({
'.'
.
join
(
shave_segments
(
layer
,
2
).
split
(
'.'
)[:
2
])
for
layer
in
downsample_blocks
[
i
]
if
'attn'
in
layer
})
attns
=
{
layer_id
:
[
key
for
key
in
downsample_blocks
[
i
]
if
f
'attn.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_blocks
)}
attns
=
{
layer_id
:
[
key
for
key
in
downsample_blocks
[
i
]
if
f
'attn.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_blocks
)}
...
@@ -148,7 +144,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
...
@@ -148,7 +144,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
paths
=
renew_attention_paths
(
attns
[
j
])
paths
=
renew_attention_paths
(
attns
[
j
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
config
=
config
)
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
config
=
config
)
mid_block_1_layers
=
[
key
for
key
in
checkpoint
if
"mid.block_1"
in
key
]
mid_block_1_layers
=
[
key
for
key
in
checkpoint
if
"mid.block_1"
in
key
]
mid_block_2_layers
=
[
key
for
key
in
checkpoint
if
"mid.block_2"
in
key
]
mid_block_2_layers
=
[
key
for
key
in
checkpoint
if
"mid.block_2"
in
key
]
mid_attn_1_layers
=
[
key
for
key
in
checkpoint
if
"mid.attn_1"
in
key
]
mid_attn_1_layers
=
[
key
for
key
in
checkpoint
if
"mid.attn_1"
in
key
]
...
@@ -186,7 +181,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
...
@@ -186,7 +181,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
paths
=
renew_resnet_paths
(
blocks
[
j
])
paths
=
renew_resnet_paths
(
blocks
[
j
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
replace_indices
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
replace_indices
])
if
any
(
'attn'
in
layer
for
layer
in
upsample_blocks
[
i
]):
if
any
(
'attn'
in
layer
for
layer
in
upsample_blocks
[
i
]):
num_attn
=
len
({
'.'
.
join
(
shave_segments
(
layer
,
2
).
split
(
'.'
)[:
2
])
for
layer
in
upsample_blocks
[
i
]
if
'attn'
in
layer
})
num_attn
=
len
({
'.'
.
join
(
shave_segments
(
layer
,
2
).
split
(
'.'
)[:
2
])
for
layer
in
upsample_blocks
[
i
]
if
'attn'
in
layer
})
attns
=
{
layer_id
:
[
key
for
key
in
upsample_blocks
[
i
]
if
f
'attn.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_blocks
)}
attns
=
{
layer_id
:
[
key
for
key
in
upsample_blocks
[
i
]
if
f
'attn.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_blocks
)}
...
@@ -220,12 +214,21 @@ if __name__ == "__main__":
...
@@ -220,12 +214,21 @@ if __name__ == "__main__":
"--dump_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the output model."
"--dump_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the output model."
)
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
checkpoint
=
torch
.
load
(
args
.
checkpoint_path
)
checkpoint
=
torch
.
load
(
args
.
checkpoint_path
)
with
open
(
args
.
config_file
)
as
f
:
with
open
(
args
.
config_file
)
as
f
:
config
=
json
.
loads
(
f
.
read
())
config
=
json
.
loads
(
f
.
read
())
converted_checkpoint
=
convert_ddpm_checkpoint
(
args
.
checkpoint_path
,
args
.
config_file
)
converted_checkpoint
=
convert_ddpm_checkpoint
(
checkpoint
,
config
)
torch
.
save
(
converted_checkpoint
,
args
.
dump_path
)
if
"ddpm"
in
config
:
del
config
[
"ddpm"
]
model
=
UNetUnconditionalModel
(
**
config
)
model
.
load_state_dict
(
converted_checkpoint
)
scheduler
=
DDPMScheduler
.
from_config
(
"/"
.
join
(
args
.
checkpoint_path
.
split
(
"/"
)[:
-
1
]))
pipe
=
DDPMPipeline
(
unet
=
model
,
scheduler
=
scheduler
)
pipe
.
save_pretrained
(
args
.
dump_path
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment