Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
3f1e9592
Commit
3f1e9592
authored
Jul 15, 2022
by
Patrick von Platen
Browse files
Fix conversion script
parent
87060e6a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
113 additions
and
19 deletions
+113
-19
debug_conversion.py
debug_conversion.py
+86
-0
scripts/__init__.py
scripts/__init__.py
+0
-0
scripts/convert_ldm_original_checkpoint_to_diffusers.py
scripts/convert_ldm_original_checkpoint_to_diffusers.py
+27
-19
No files found.
debug_conversion.py
0 → 100755
View file @
3f1e9592
#!/usr/bin/env python3
import
json
import
os
from
diffusers
import
UNetUnconditionalModel
from
scripts.convert_ldm_original_checkpoint_to_diffusers
import
convert_ldm_checkpoint
from
huggingface_hub
import
hf_hub_download
import
torch
# Hub repository that holds the original LDM checkpoint, and the folder inside
# that repository where the UNet weights and config live.
model_id = "fusing/latent-diffusion-celeba-256"
subfolder = "unet"
# Alternative repository with the files at the repository root:
#model_id = "fusing/unet-ldm-dummy"
#subfolder = None

checkpoint = "diffusion_model.pt"
config = "config.json"

if subfolder is not None:
    # File names passed to hf_hub_download are repository-relative and use "/"
    # as separator; os.path.join would insert "\\" on Windows.
    checkpoint = f"{subfolder}/{checkpoint}"
    config = f"{subfolder}/{config}"

# map_location="cpu" keeps the script usable on machines without a GPU even if
# the checkpoint was saved from CUDA tensors. The comparison further down moves
# every tensor to CPU anyway.
original_checkpoint = torch.load(hf_hub_download(model_id, checkpoint), map_location="cpu")
config_path = hf_hub_download(model_id, config)

with open(config_path, encoding="utf-8") as f:
    config = json.load(f)

# From here on `config` is the parsed JSON config and `checkpoint` is the state
# dict produced by the conversion script.
checkpoint = convert_ldm_checkpoint(original_checkpoint, config)
def current_codebase_conversion():
    """Load the UNet through the current codebase and return its state dict.

    The model is obtained with ``UNetUnconditionalModel.from_pretrained`` using
    the module-level ``model_id`` and ``subfolder`` (with ``ldm=True``), put in
    eval mode, and run once on seeded random noise with gradients disabled.
    The result of that forward pass is discarded; only ``state_dict()`` of the
    model is returned.
    """
    unet = UNetUnconditionalModel.from_pretrained(model_id, subfolder=subfolder, ldm=True)
    unet.eval()

    # Seed the CPU generator, and every CUDA generator when CUDA is present,
    # so the random input below is reproducible.
    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    sample = torch.randn(
        1,
        unet.config.in_channels,
        unet.config.image_size,
        unet.config.image_size,
    )
    batch_size = sample.shape[0]
    timesteps = torch.tensor([10] * batch_size)

    with torch.no_grad():
        _ = unet(sample, timesteps)

    return unet.state_dict()
# State dict as loaded by the current codebase; also written to disk in the
# working directory.
currently_converted_checkpoint = current_codebase_conversion()
torch.save(
    currently_converted_checkpoint,
    "currently_converted_checkpoint.pt",
)
def diff_between_checkpoints(ch_0, ch_1):
    """Print a report describing how two checkpoints differ.

    Args:
        ch_0: Mapping from parameter name to tensor.
        ch_1: Mapping from parameter name to tensor, compared against ``ch_0``.

    Returns:
        None. All results are written to stdout.

    If the two mappings do not have the same keys, the names present on only
    one side are listed and the function returns without comparing values, so
    it no longer claims that "All keys are equal" in that case. Otherwise each
    pair of tensors is moved to CPU and compared: a shape mismatch is reported
    as such, and tensors of equal shape are compared with ``torch.allclose``.
    """
    keys_0 = set(ch_0.keys())
    keys_1 = set(ch_1.keys())

    if keys_0 != keys_1:
        only_in_0 = keys_0 - keys_1
        only_in_1 = keys_1 - keys_0

        print(f"Contained in ch_0 and not in ch_1 (Total: {len(only_in_0)})")
        for key in sorted(only_in_0):
            print(f"\t{key}")

        print(f"Contained in ch_1 and not in ch_0 (Total: {len(only_in_1)})")
        for key in sorted(only_in_1):
            print(f"\t{key}")

        # Values are only compared when the key sets match. Stop here so the
        # summary at the end is not printed for checkpoints that were never
        # compared.
        return

    print("Keys are the same between the two checkpoints")

    non_equal_keys = []
    for key in ch_0.keys():
        # Check shapes explicitly: torch.allclose broadcasts its inputs, so
        # differently shaped but broadcastable tensors would otherwise be
        # compared element-wise and could be reported as equal.
        if ch_0[key].size() != ch_1[key].size():
            non_equal_keys.append(f'{key}. Diff in shape: {ch_0[key].size()} vs {ch_1[key].size()}')
            continue

        try:
            tensor_0 = ch_0[key].cpu()
            tensor_1 = ch_1[key].cpu()
            if not torch.allclose(tensor_0, tensor_1):
                non_equal_keys.append(f'{key}. Diff: {torch.max(torch.abs(tensor_0 - tensor_1))}')
        except RuntimeError as e:
            # Kept from the original: any RuntimeError raised by the
            # comparison is printed and the key is recorded as non-equal.
            print(e)
            non_equal_keys.append(f'{key}. Diff in shape: {ch_0[key].size()} vs {ch_1[key].size()}')

    if non_equal_keys:
        report = '\n\t'.join(non_equal_keys)
        print(f"These keys do not satisfy equivalence requirement:\n\t{report}")
    else:
        print("All keys are equal across checkpoints.")
# Compare the state dict loaded by the current codebase (ch_0) with the one
# produced by convert_ldm_checkpoint (ch_1).
diff_between_checkpoints(
    ch_0=currently_converted_checkpoint,
    ch_1=checkpoint,
)
scripts/__init__.py
0 → 100644
View file @
3f1e9592
scripts/convert_ldm_original_checkpoint_to_diffusers.py
View file @
3f1e9592
...
...
@@ -72,7 +72,7 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
return
mapping
def
assign_to_checkpoint
(
paths
,
checkpoint
,
old_checkpoint
,
attention_paths_to_split
=
None
,
additional_replacements
=
None
):
def
assign_to_checkpoint
(
paths
,
checkpoint
,
old_checkpoint
,
attention_paths_to_split
=
None
,
additional_replacements
=
None
,
config
=
None
):
"""
This does the final conversion step: take locally converted weights and apply a global renaming
to them. It splits attention layers, and takes into account additional replacements
...
...
@@ -85,11 +85,19 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
# Splits the attention layers into three variables.
if
attention_paths_to_split
is
not
None
:
for
path
,
path_map
in
attention_paths_to_split
.
items
():
query
,
key
,
value
=
torch
.
split
(
old_checkpoint
[
path
],
int
(
old_checkpoint
[
path
].
shape
[
0
]
/
3
))
old_tensor
=
old_checkpoint
[
path
]
channels
=
old_tensor
.
shape
[
0
]
//
3
checkpoint
[
path_map
[
'query'
]]
=
query
checkpoint
[
path_map
[
'key'
]]
=
key
checkpoint
[
path_map
[
'value'
]]
=
value
target_shape
=
(
-
1
,
channels
)
if
len
(
old_tensor
.
shape
)
==
3
else
(
-
1
)
num_heads
=
old_tensor
.
shape
[
0
]
//
config
[
"num_head_channels"
]
//
3
old_tensor
=
old_tensor
.
reshape
((
num_heads
,
3
*
channels
//
num_heads
)
+
old_tensor
.
shape
[
1
:])
query
,
key
,
value
=
old_tensor
.
split
(
channels
//
num_heads
,
dim
=
1
)
checkpoint
[
path_map
[
'query'
]]
=
query
.
reshape
(
target_shape
)
checkpoint
[
path_map
[
'key'
]]
=
key
.
reshape
(
target_shape
)
checkpoint
[
path_map
[
'value'
]]
=
value
.
reshape
(
target_shape
)
for
path
in
paths
:
new_path
=
path
[
'new'
]
...
...
@@ -107,7 +115,11 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
for
replacement
in
additional_replacements
:
new_path
=
new_path
.
replace
(
replacement
[
'old'
],
replacement
[
'new'
])
checkpoint
[
new_path
]
=
old_checkpoint
[
path
[
'old'
]]
# proj_attn.weight has to be converted from conv 1D to linear
if
"proj_attn.weight"
in
new_path
:
checkpoint
[
new_path
]
=
old_checkpoint
[
path
[
'old'
]][:,
:,
0
]
else
:
checkpoint
[
new_path
]
=
old_checkpoint
[
path
[
'old'
]]
def
convert_ldm_checkpoint
(
checkpoint
,
config
):
...
...
@@ -155,7 +167,7 @@ def convert_ldm_checkpoint(checkpoint, config):
paths
=
renew_resnet_paths
(
resnets
)
meta_path
=
{
'old'
:
f
'input_blocks.
{
i
}
.0'
,
'new'
:
f
'downsample_blocks.
{
block_id
}
.resnets.
{
layer_in_block_id
}
'
}
resnet_op
=
{
'old'
:
'resnets.2.op'
,
'new'
:
'downsamplers.0.op'
}
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
meta_path
,
resnet_op
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
meta_path
,
resnet_op
]
,
config
=
config
)
if
len
(
attentions
):
paths
=
renew_attention_paths
(
attentions
)
...
...
@@ -177,19 +189,19 @@ def convert_ldm_checkpoint(checkpoint, config):
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
meta_path
],
attention_paths_to_split
=
to_split
attention_paths_to_split
=
to_split
,
config
=
config
)
resnet_0
=
middle_blocks
[
0
]
attentions
=
middle_blocks
[
1
]
resnet_1
=
middle_blocks
[
2
]
resnet_0_paths
=
renew_resnet_paths
(
resnet_0
)
assign_to_checkpoint
(
resnet_0_paths
,
new_checkpoint
,
checkpoint
)
assign_to_checkpoint
(
resnet_0_paths
,
new_checkpoint
,
checkpoint
,
config
=
config
)
resnet_1_paths
=
renew_resnet_paths
(
resnet_1
)
assign_to_checkpoint
(
resnet_1_paths
,
new_checkpoint
,
checkpoint
)
assign_to_checkpoint
(
resnet_1_paths
,
new_checkpoint
,
checkpoint
,
config
=
config
)
attentions_paths
=
renew_attention_paths
(
attentions
)
to_split
=
{
...
...
@@ -204,7 +216,7 @@ def convert_ldm_checkpoint(checkpoint, config):
'value'
:
'mid.attentions.0.value.weight'
,
},
}
assign_to_checkpoint
(
attentions_paths
,
new_checkpoint
,
checkpoint
,
attention_paths_to_split
=
to_split
)
assign_to_checkpoint
(
attentions_paths
,
new_checkpoint
,
checkpoint
,
attention_paths_to_split
=
to_split
,
config
=
config
)
for
i
in
range
(
num_output_blocks
):
block_id
=
i
//
(
config
[
'num_res_blocks'
]
+
1
)
...
...
@@ -227,7 +239,7 @@ def convert_ldm_checkpoint(checkpoint, config):
paths
=
renew_resnet_paths
(
resnets
)
meta_path
=
{
'old'
:
f
'output_blocks.
{
i
}
.0'
,
'new'
:
f
'upsample_blocks.
{
block_id
}
.resnets.
{
layer_in_block_id
}
'
}
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
meta_path
])
assign_to_checkpoint
(
paths
,
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
meta_path
]
,
config
=
config
)
if
[
'conv.weight'
,
'conv.bias'
]
in
output_block_list
.
values
():
index
=
list
(
output_block_list
.
values
()).
index
([
'conv.weight'
,
'conv.bias'
])
...
...
@@ -238,7 +250,6 @@ def convert_ldm_checkpoint(checkpoint, config):
if
len
(
attentions
)
==
2
:
attentions
=
[]
if
len
(
attentions
):
paths
=
renew_attention_paths
(
attentions
)
meta_path
=
{
...
...
@@ -262,7 +273,8 @@ def convert_ldm_checkpoint(checkpoint, config):
new_checkpoint
,
checkpoint
,
additional_replacements
=
[
meta_path
],
attention_paths_to_split
=
to_split
if
any
(
'qkv'
in
key
for
key
in
attentions
)
else
None
attention_paths_to_split
=
to_split
if
any
(
'qkv'
in
key
for
key
in
attentions
)
else
None
,
config
=
config
,
)
else
:
resnet_0_paths
=
renew_resnet_paths
(
output_block_layers
,
n_shave_prefix_segments
=
1
)
...
...
@@ -296,7 +308,6 @@ if __name__ == "__main__":
args
=
parser
.
parse_args
()
checkpoint
=
torch
.
load
(
args
.
checkpoint_path
)
with
open
(
args
.
config_file
)
as
f
:
...
...
@@ -304,6 +315,3 @@ if __name__ == "__main__":
converted_checkpoint
=
convert_ldm_checkpoint
(
checkpoint
,
config
)
torch
.
save
(
checkpoint
,
args
.
dump_path
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment