renzhc / diffusers_dcu · Commits

Commit 6cabc599 (unverified)
Authored Jul 19, 2022 by Lysandre Debut; committed by GitHub on Jul 19, 2022
Parent: 36b459f6

DDPM Conversion (#94)

* DDPM
* Fixes
* Edit tests

Showing 3 changed files with 243 additions and 12 deletions:
scripts/convert_ddpm_original_checkpoint_to_diffusers.py  +231 -0
scripts/convert_ldm_original_checkpoint_to_diffusers.py   +4 -4
tests/test_modeling_utils.py                              +8 -8

scripts/convert_ddpm_original_checkpoint_to_diffusers.py (new file, mode 100644)
from diffusers import UNetUnconditionalModel
import argparse
import json
import torch


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return '.'.join(path.split('.')[n_shave_prefix_segments:])
    else:
        return '.'.join(path.split('.')[:n_shave_prefix_segments])
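
# For instance, with hypothetical keys (illustration only, not taken from a real
# checkpoint):
#
#   shave_segments('down.1.block.0.conv1.weight', 2)   # -> 'block.0.conv1.weight'
#   shave_segments('down.1.block.0.conv1.weight', -2)  # -> 'down.1.block.0'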


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    mapping = []
    for old_item in old_list:
        new_item = old_item
        new_item = new_item.replace('block.', 'resnets.')
        new_item = new_item.replace('conv_shorcut', 'conv1')
        new_item = new_item.replace('nin_shortcut', 'conv_shortcut')
        new_item = new_item.replace('temb_proj', 'time_emb_proj')

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({'old': old_item, 'new': new_item})

    return mapping


def renew_attention_paths(old_list, n_shave_prefix_segments=0, in_mid=False):
    mapping = []
    for old_item in old_list:
        new_item = old_item

        # In `model.mid`, the layer is called `attn`.
        if not in_mid:
            new_item = new_item.replace('attn', 'attentions')
        new_item = new_item.replace('.k.', '.key.')
        new_item = new_item.replace('.v.', '.value.')
        new_item = new_item.replace('.q.', '.query.')

        new_item = new_item.replace('proj_out', 'proj_attn')
        new_item = new_item.replace('norm', 'group_norm')

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({'old': old_item, 'new': new_item})

    return mapping
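
# To make the renaming concrete, the two helpers above return mappings like the
# following (hypothetical keys, illustration only; the 'down.'/'up.' prefixes are
# renamed later, in assign_to_checkpoint):
#
#   renew_resnet_paths(['down.0.block.0.temb_proj.weight'])
#   # -> [{'old': 'down.0.block.0.temb_proj.weight',
#   #      'new': 'down.0.resnets.0.time_emb_proj.weight'}]
#
#   renew_attention_paths(['down.1.attn.0.norm.weight'])
#   # -> [{'old': 'down.1.attn.0.norm.weight',
#   #      'new': 'down.1.attentions.0.group_norm.weight'}]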


def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    if attention_paths_to_split is not None:
        if config is None:
            raise ValueError("Please specify the config if setting 'attention_paths_to_split' to 'True'.")

        # Split fused qkv projections into separate query/key/value tensors.
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map['query']] = query.reshape(target_shape).squeeze()
            checkpoint[path_map['key']] = key.reshape(target_shape).squeeze()
            checkpoint[path_map['value']] = value.reshape(target_shape).squeeze()

    for path in paths:
        new_path = path['new']

        # These keys were already assigned above in the qkv split.
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        new_path = new_path.replace('down.', 'downsample_blocks.')
        new_path = new_path.replace('up.', 'upsample_blocks.')

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement['old'], replacement['new'])

        # Attention projections are stored as 1x1 convolutions in the old format,
        # so the trailing singleton dimensions are squeezed away.
        if 'attentions' in new_path:
            checkpoint[new_path] = old_checkpoint[path['old']].squeeze()
        else:
            checkpoint[new_path] = old_checkpoint[path['old']]
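
# A minimal round trip through the helpers above, on a toy one-key state dict
# (key name and shape are illustrative only):
#
#   old = {'down.0.block.0.temb_proj.weight': torch.zeros(8, 8)}
#   new = {}
#   assign_to_checkpoint(renew_resnet_paths(list(old)), new, old)
#   # new == {'downsample_blocks.0.resnets.0.time_emb_proj.weight': <the tensor>}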


def convert_ddpm_checkpoint(checkpoint, config):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """
    new_checkpoint = {}

    new_checkpoint['time_embedding.linear_1.weight'] = checkpoint['temb.dense.0.weight']
    new_checkpoint['time_embedding.linear_1.bias'] = checkpoint['temb.dense.0.bias']
    new_checkpoint['time_embedding.linear_2.weight'] = checkpoint['temb.dense.1.weight']
    new_checkpoint['time_embedding.linear_2.bias'] = checkpoint['temb.dense.1.bias']

    new_checkpoint['conv_norm_out.weight'] = checkpoint['norm_out.weight']
    new_checkpoint['conv_norm_out.bias'] = checkpoint['norm_out.bias']

    new_checkpoint['conv_in.weight'] = checkpoint['conv_in.weight']
    new_checkpoint['conv_in.bias'] = checkpoint['conv_in.bias']
    new_checkpoint['conv_out.weight'] = checkpoint['conv_out.weight']
    new_checkpoint['conv_out.bias'] = checkpoint['conv_out.bias']

    num_downsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
    downsample_blocks = {layer_id: [key for key in checkpoint if f'down.{layer_id}' in key] for layer_id in range(num_downsample_blocks)}

    num_upsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
    upsample_blocks = {layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_upsample_blocks)}
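
    # The two set comprehensions above count distinct 'down.<i>' / 'up.<i>'
    # prefixes. For example (hypothetical keys):
    #
    #   keys = ['down.0.block.0.conv1.weight', 'down.0.block.1.conv1.weight',
    #           'down.1.attn.0.norm.weight']
    #   {'.'.join(k.split('.')[:2]) for k in keys}  # -> {'down.0', 'down.1'}, i.e. 2 blocks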

    for i in range(num_downsample_blocks):
        block_id = (i - 1) // (config['num_res_blocks'] + 1)
        layer_in_block_id = (i - 1) % (config['num_res_blocks'] + 1)

        if any('downsample' in layer for layer in downsample_blocks[i]):
            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']

        if any('block' in layer for layer in downsample_blocks[i]):
            num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in downsample_blocks[i] if 'block' in layer})
            blocks = {layer_id: [key for key in downsample_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}

            if num_blocks > 0:
                for j in range(config['num_res_blocks']):
                    paths = renew_resnet_paths(blocks[j])
                    assign_to_checkpoint(paths, new_checkpoint, checkpoint)

        if any('attn' in layer for layer in downsample_blocks[i]):
            num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in downsample_blocks[i] if 'attn' in layer})
            attns = {layer_id: [key for key in downsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_attn)}

            if num_attn > 0:
                for j in range(config['num_res_blocks']):
                    paths = renew_attention_paths(attns[j])
                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)

    mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
    mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
    mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]

    # Mid new 2
    paths = renew_resnet_paths(mid_block_1_layers)
    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
        {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'}
    ])

    paths = renew_resnet_paths(mid_block_2_layers)
    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
        {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_2', 'new': 'resnets.1'}
    ])

    paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[
        {'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}
    ])

    for i in range(num_upsample_blocks):
        # The new format reverses the ordering of the up blocks.
        block_id = num_upsample_blocks - 1 - i

        if any('upsample' in layer for layer in upsample_blocks[i]):
            new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[f'up.{i}.upsample.conv.weight']
            new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[f'up.{i}.upsample.conv.bias']

        if any('block' in layer for layer in upsample_blocks[i]):
            num_blocks = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'block' in layer})
            blocks = {layer_id: [key for key in upsample_blocks[i] if f'block.{layer_id}' in key] for layer_id in range(num_blocks)}

            if num_blocks > 0:
                for j in range(config['num_res_blocks'] + 1):
                    replace_indices = {'old': f'upsample_blocks.{i}', 'new': f'upsample_blocks.{block_id}'}
                    paths = renew_resnet_paths(blocks[j])
                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])

        if any('attn' in layer for layer in upsample_blocks[i]):
            num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'attn' in layer})
            attns = {layer_id: [key for key in upsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_attn)}

            if num_attn > 0:
                for j in range(config['num_res_blocks'] + 1):
                    replace_indices = {'old': f'upsample_blocks.{i}', 'new': f'upsample_blocks.{block_id}'}
                    paths = renew_attention_paths(attns[j])
                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])

    # Swap the temporary 'mid_new_2' prefix back to 'mid' once all keys are assigned.
    new_checkpoint = {k.replace('mid_new_2', 'mid'): v for k, v in new_checkpoint.items()}
    return new_checkpoint
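
# Sanity-check sketch (illustration only; shapes and config values are made up):
# a fake state dict containing just the mandatory top-level keys passes through
# the converter, since the block and mid loops simply find nothing to remap.
#
#   fake = {
#       'temb.dense.0.weight': torch.zeros(16, 4), 'temb.dense.0.bias': torch.zeros(16),
#       'temb.dense.1.weight': torch.zeros(16, 16), 'temb.dense.1.bias': torch.zeros(16),
#       'norm_out.weight': torch.zeros(4), 'norm_out.bias': torch.zeros(4),
#       'conv_in.weight': torch.zeros(4, 3, 3, 3), 'conv_in.bias': torch.zeros(4),
#       'conv_out.weight': torch.zeros(3, 4, 3, 3), 'conv_out.bias': torch.zeros(3),
#   }
#   converted = convert_ddpm_checkpoint(fake, {'num_res_blocks': 2})
#   assert 'time_embedding.linear_1.weight' in converted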


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the architecture.",
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    args = parser.parse_args()

    checkpoint = torch.load(args.checkpoint_path)

    with open(args.config_file) as f:
        config = json.loads(f.read())

    # The converter expects the loaded state dict and config, not the file paths.
    converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)

    torch.save(converted_checkpoint, args.dump_path)
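
The expected invocation, with placeholder paths (the flags are exactly the three argparse arguments defined above):

    python scripts/convert_ddpm_original_checkpoint_to_diffusers.py \
        --checkpoint_path model.ckpt \
        --config_file config.json \
        --dump_path converted_checkpoint.pt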

scripts/convert_ldm_original_checkpoint_to_diffusers.py

@@ -124,7 +124,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
 def convert_ldm_checkpoint(checkpoint, config):
     """
-    Takes a state dict and the path to
+    Takes a state dict and a config, and returns a converted checkpoint.
     """
     new_checkpoint = {}
@@ -142,15 +142,15 @@ def convert_ldm_checkpoint(checkpoint, config):
     new_checkpoint['conv_out.bias'] = checkpoint['out.2.bias']

     # Retrieves the keys for the input blocks only
-    num_input_blocks = len({shave_segments(layer, -2) for layer in checkpoint if 'input_blocks' in layer})
+    num_input_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'input_blocks' in layer})
     input_blocks = {layer_id: [key for key in checkpoint if f'input_blocks.{layer_id}' in key] for layer_id in range(num_input_blocks)}

     # Retrieves the keys for the middle blocks only
-    num_middle_blocks = len({shave_segments(layer, -2) for layer in checkpoint if 'middle_block' in layer})
+    num_middle_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'middle_block' in layer})
     middle_blocks = {layer_id: [key for key in checkpoint if f'middle_block.{layer_id}' in key] for layer_id in range(num_middle_blocks)}

     # Retrieves the keys for the output blocks only
-    num_output_blocks = len({shave_segments(layer, -2) for layer in checkpoint if 'output_blocks' in layer})
+    num_output_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'output_blocks' in layer})
     output_blocks = {layer_id: [key for key in checkpoint if f'output_blocks.{layer_id}' in key] for layer_id in range(num_output_blocks)}

     for i in range(1, num_input_blocks):
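
Why this change: shave_segments(layer, -2) drops the last two segments, so the surviving prefix depends on how deep each key is, and deeper keys inflate the block count. Taking the first two segments always yields exactly 'input_blocks.<i>' (hypothetical LDM-style keys, illustration only):

    shave_segments('input_blocks.1.0.in_layers.0.weight', -2)       # -> 'input_blocks.1.0.in_layers'
    '.'.join('input_blocks.1.0.in_layers.0.weight'.split('.')[:2])  # -> 'input_blocks.1'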

tests/test_modeling_utils.py

@@ -929,7 +929,7 @@ class PipelineTesterMixin(unittest.TestCase):
     @slow
     def test_from_pretrained_hub(self):
-        model_path = "fusing/ddpm-cifar10"
+        model_path = "google/ddpm-cifar10"

         ddpm = DDPMPipeline.from_pretrained(model_path)
         ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
@@ -947,9 +947,9 @@ class PipelineTesterMixin(unittest.TestCase):
     @slow
     def test_ddpm_cifar10(self):
-        model_id = "fusing/ddpm-cifar10"
+        model_id = "google/ddpm-cifar10"

-        unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
+        unet = UNetUnconditionalModel.from_pretrained(model_id)
         scheduler = DDPMScheduler.from_config(model_id)
         scheduler = scheduler.set_format("pt")
@@ -968,9 +968,9 @@ class PipelineTesterMixin(unittest.TestCase):
     @slow
     def test_ddim_lsun(self):
-        model_id = "fusing/ddpm-lsun-bedroom-ema"
+        model_id = "google/ddpm-lsun-bedroom-ema"

-        unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
+        unet = UNetUnconditionalModel.from_pretrained(model_id)
         scheduler = DDIMScheduler.from_config(model_id)
         ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
@@ -988,9 +988,9 @@ class PipelineTesterMixin(unittest.TestCase):
     @slow
     def test_ddim_cifar10(self):
-        model_id = "fusing/ddpm-cifar10"
+        model_id = "google/ddpm-cifar10"

-        unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
+        unet = UNetUnconditionalModel.from_pretrained(model_id)
         scheduler = DDIMScheduler(tensor_format="pt")
         ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
@@ -1008,7 +1008,7 @@ class PipelineTesterMixin(unittest.TestCase):
     @slow
     def test_pndm_cifar10(self):
-        model_id = "fusing/ddpm-cifar10"
+        model_id = "google/ddpm-cifar10"

         unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
         scheduler = PNDMScheduler(tensor_format="pt")