renzhc / diffusers_dcu

Commit 3f1e9592, authored Jul 15, 2022 by Patrick von Platen

    Fix conversion script

Parent: 87060e6a

Showing 3 changed files with 113 additions and 19 deletions (+113 -19)
debug_conversion.py                                        +86   -0
scripts/__init__.py                                         +0   -0
scripts/convert_ldm_original_checkpoint_to_diffusers.py    +27  -19
debug_conversion.py  (new file, mode 100755)
#!/usr/bin/env python3
import json
import os

from diffusers import UNetUnconditionalModel
from scripts.convert_ldm_original_checkpoint_to_diffusers import convert_ldm_checkpoint
from huggingface_hub import hf_hub_download
import torch

model_id = "fusing/latent-diffusion-celeba-256"
subfolder = "unet"

#model_id = "fusing/unet-ldm-dummy"
#subfolder = None

checkpoint = "diffusion_model.pt"
config = "config.json"

if subfolder is not None:
    checkpoint = os.path.join(subfolder, checkpoint)
    config = os.path.join(subfolder, config)

original_checkpoint = torch.load(hf_hub_download(model_id, checkpoint))
config_path = hf_hub_download(model_id, config)

with open(config_path) as f:
    config = json.load(f)

checkpoint = convert_ldm_checkpoint(original_checkpoint, config)


def current_codebase_conversion():
    model = UNetUnconditionalModel.from_pretrained(model_id, subfolder=subfolder, ldm=True)
    model.eval()

    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
    time_step = torch.tensor([10] * noise.shape[0])

    with torch.no_grad():
        output = model(noise, time_step)

    return model.state_dict()


currently_converted_checkpoint = current_codebase_conversion()
torch.save(currently_converted_checkpoint, 'currently_converted_checkpoint.pt')


def diff_between_checkpoints(ch_0, ch_1):
    all_layers_included = False

    if not set(ch_0.keys()) == set(ch_1.keys()):
        print(f"Contained in ch_0 and not in ch_1 (Total: {len((set(ch_0.keys()) - set(ch_1.keys())))})")
        for key in sorted(list((set(ch_0.keys()) - set(ch_1.keys())))):
            print(f"\t{key}")

        print(f"Contained in ch_1 and not in ch_0 (Total: {len((set(ch_1.keys()) - set(ch_0.keys())))})")
        for key in sorted(list((set(ch_1.keys()) - set(ch_0.keys())))):
            print(f"\t{key}")
    else:
        print("Keys are the same between the two checkpoints")
        all_layers_included = True

    keys = ch_0.keys()
    non_equal_keys = []

    if all_layers_included:
        for key in keys:
            try:
                if not torch.allclose(ch_0[key].cpu(), ch_1[key].cpu()):
                    non_equal_keys.append(f'{key}. Diff: {torch.max(torch.abs(ch_0[key].cpu() - ch_1[key].cpu()))}')
            except RuntimeError as e:
                print(e)
                non_equal_keys.append(f'{key}. Diff in shape: {ch_0[key].size()} vs {ch_1[key].size()}')

        if len(non_equal_keys):
            non_equal_keys = '\n\t'.join(non_equal_keys)
            print(f"These keys do not satisfy equivalence requirement:\n\t{non_equal_keys}")
        else:
            print("All keys are equal across checkpoints.")


diff_between_checkpoints(currently_converted_checkpoint, checkpoint)
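The debug script above converts the original LDM checkpoint with convert_ldm_checkpoint, loads the same weights through the current UNetUnconditionalModel code path, and then compares the two state dicts key by key. As a minimal sketch of how diff_between_checkpoints behaves (reusing the function defined above; the toy key names and tensors below are invented for illustration and are not part of the commit):

import torch

# Two invented "state dicts" with identical keys but slightly different values.
ch_a = {"conv.weight": torch.ones(2, 2), "conv.bias": torch.zeros(2)}
ch_b = {"conv.weight": torch.ones(2, 2) + 1e-3, "conv.bias": torch.zeros(2)}

diff_between_checkpoints(ch_a, ch_b)
# Prints roughly:
#   Keys are the same between the two checkpoints
#   These keys do not satisfy equivalence requirement:
#       conv.weight. Diff: 0.001...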
scripts/__init__.py  (new file, mode 100644)
scripts/convert_ldm_original_checkpoint_to_diffusers.py
@@ -72,7 +72,7 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
     return mapping
 
 
-def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None):
+def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
     """
     This does the final conversion step: take locally converted weights and apply a global renaming
     to them. It splits attention layers, and takes into account additional replacements
@@ -85,11 +85,19 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
     # Splits the attention layers into three variables.
     if attention_paths_to_split is not None:
         for path, path_map in attention_paths_to_split.items():
-            query, key, value = torch.split(old_checkpoint[path], int(old_checkpoint[path].shape[0] / 3))
+            old_tensor = old_checkpoint[path]
+            channels = old_tensor.shape[0] // 3
 
-            checkpoint[path_map['query']] = query
-            checkpoint[path_map['key']] = key
-            checkpoint[path_map['value']] = value
+            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
+
+            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
+
+            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
+            query, key, value = old_tensor.split(channels // num_heads, dim=1)
+
+            checkpoint[path_map['query']] = query.reshape(target_shape)
+            checkpoint[path_map['key']] = key.reshape(target_shape)
+            checkpoint[path_map['value']] = value.reshape(target_shape)
 
     for path in paths:
         new_path = path['new']
@@ -107,7 +115,11 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
             for replacement in additional_replacements:
                 new_path = new_path.replace(replacement['old'], replacement['new'])
 
-        checkpoint[new_path] = old_checkpoint[path['old']]
+        # proj_attn.weight has to be converted from conv 1D to linear
+        if "proj_attn.weight" in new_path:
+            checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0]
+        else:
+            checkpoint[new_path] = old_checkpoint[path['old']]
 
 
 def convert_ldm_checkpoint(checkpoint, config):
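The proj_attn.weight branch added above drops the trailing kernel dimension because a kernel-size-1 Conv1d stores its weight as (out_channels, in_channels, 1), while nn.Linear expects (out_features, in_features). A quick sketch with toy sizes (independent of the commit) showing that slicing with [:, :, 0] preserves the computation:

import torch
import torch.nn.functional as F

conv_weight = torch.randn(16, 16, 1)    # Conv1d(16, 16, kernel_size=1) weight layout
linear_weight = conv_weight[:, :, 0]    # -> (16, 16), the nn.Linear layout

x = torch.randn(1, 16)                  # one token with 16 channels
out_linear = x @ linear_weight.T
out_conv = F.conv1d(x.unsqueeze(-1), conv_weight).squeeze(-1)
print(torch.allclose(out_linear, out_conv, atol=1e-6))  # True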
@@ -155,7 +167,7 @@ def convert_ldm_checkpoint(checkpoint, config):
             paths = renew_resnet_paths(resnets)
             meta_path = {'old': f'input_blocks.{i}.0', 'new': f'downsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
             resnet_op = {'old': 'resnets.2.op', 'new': 'downsamplers.0.op'}
-            assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op])
+            assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config)
 
         if len(attentions):
             paths = renew_attention_paths(attentions)
@@ -177,19 +189,19 @@ def convert_ldm_checkpoint(checkpoint, config):
                 new_checkpoint,
                 checkpoint,
                 additional_replacements=[meta_path],
-                attention_paths_to_split=to_split
+                attention_paths_to_split=to_split, config=config
             )
 
     resnet_0 = middle_blocks[0]
     attentions = middle_blocks[1]
     resnet_1 = middle_blocks[2]
 
     resnet_0_paths = renew_resnet_paths(resnet_0)
-    assign_to_checkpoint(resnet_0_paths, new_checkpoint, checkpoint)
+    assign_to_checkpoint(resnet_0_paths, new_checkpoint, checkpoint, config=config)
 
     resnet_1_paths = renew_resnet_paths(resnet_1)
-    assign_to_checkpoint(resnet_1_paths, new_checkpoint, checkpoint)
+    assign_to_checkpoint(resnet_1_paths, new_checkpoint, checkpoint, config=config)
 
     attentions_paths = renew_attention_paths(attentions)
 
     to_split = {
@@ -204,7 +216,7 @@ def convert_ldm_checkpoint(checkpoint, config):
             'value': 'mid.attentions.0.value.weight',
         },
     }
-    assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split)
+    assign_to_checkpoint(attentions_paths, new_checkpoint, checkpoint, attention_paths_to_split=to_split, config=config)
 
     for i in range(num_output_blocks):
         block_id = i // (config['num_res_blocks'] + 1)
@@ -227,7 +239,7 @@ def convert_ldm_checkpoint(checkpoint, config):
 
             paths = renew_resnet_paths(resnets)
             meta_path = {'old': f'output_blocks.{i}.0', 'new': f'upsample_blocks.{block_id}.resnets.{layer_in_block_id}'}
-            assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path])
+            assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[meta_path], config=config)
 
             if ['conv.weight', 'conv.bias'] in output_block_list.values():
                 index = list(output_block_list.values()).index(['conv.weight', 'conv.bias'])
@@ -238,7 +250,6 @@ def convert_ldm_checkpoint(checkpoint, config):
             if len(attentions) == 2:
                 attentions = []
 
             if len(attentions):
                 paths = renew_attention_paths(attentions)
                 meta_path = {
@@ -262,7 +273,8 @@ def convert_ldm_checkpoint(checkpoint, config):
                     new_checkpoint,
                     checkpoint,
                     additional_replacements=[meta_path],
-                    attention_paths_to_split=to_split if any('qkv' in key for key in attentions) else None
+                    attention_paths_to_split=to_split if any('qkv' in key for key in attentions) else None,
+                    config=config,
                 )
         else:
             resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
@@ -296,7 +308,6 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     checkpoint = torch.load(args.checkpoint_path)
 
     with open(args.config_file) as f:
@@ -304,6 +315,3 @@ if __name__ == "__main__":
     converted_checkpoint = convert_ldm_checkpoint(checkpoint, config)
 
     torch.save(checkpoint, args.dump_path)
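Taken together, the changes thread config through every assign_to_checkpoint call so the attention split can read num_head_channels. A hedged sketch of driving the updated conversion programmatically, mirroring debug_conversion.py (the file paths below are placeholders, not from the commit):

import json
import torch
from scripts.convert_ldm_original_checkpoint_to_diffusers import convert_ldm_checkpoint

original = torch.load("diffusion_model.pt", map_location="cpu")   # placeholder path
with open("config.json") as f:                                    # placeholder path
    config = json.load(f)

converted = convert_ldm_checkpoint(original, config)  # config now also controls the qkv split
torch.save(converted, "converted_diffusers_checkpoint.pt")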