Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
6cabc599
Unverified
Commit
6cabc599
authored
Jul 19, 2022
by
Lysandre Debut
Committed by
GitHub
Jul 19, 2022
Browse files
DDPM Conversion (#94)
* DDPM * Fixes * Edit tests
parent
36b459f6
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
243 additions
and
12 deletions
+243
-12
scripts/convert_ddpm_original_checkpoint_to_diffusers.py
scripts/convert_ddpm_original_checkpoint_to_diffusers.py
+231
-0
scripts/convert_ldm_original_checkpoint_to_diffusers.py
scripts/convert_ldm_original_checkpoint_to_diffusers.py
+4
-4
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+8
-8
No files found.
scripts/convert_ddpm_original_checkpoint_to_diffusers.py
0 → 100644
View file @
6cabc599
from
diffusers
import
UNetUnconditionalModel
import
argparse
import
json
import
torch
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Drop dot-separated segments from one end of `path`.

    A non-negative `n_shave_prefix_segments` removes that many leading segments;
    a negative value removes that many trailing segments. The remaining segments
    are re-joined with '.'.
    """
    parts = path.split('.')
    if n_shave_prefix_segments < 0:
        kept = parts[:n_shave_prefix_segments]
    else:
        kept = parts[n_shave_prefix_segments:]
    return '.'.join(kept)
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Build a list of {'old': ..., 'new': ...} dicts renaming resnet parameter paths.

    Each path in `old_list` goes through a fixed sequence of substring
    substitutions and is then passed through `shave_segments` with
    `n_shave_prefix_segments`. The untouched input path is kept under 'old'.
    """
    # Applied in order; each substitution sees the output of the previous one.
    substitutions = (
        ('block.', 'resnets.'),
        ('conv_shorcut', 'conv1'),
        ('nin_shortcut', 'conv_shortcut'),
        ('temb_proj', 'time_emb_proj'),
    )

    mapping = []
    for original in old_list:
        renamed = original
        for needle, replacement in substitutions:
            renamed = renamed.replace(needle, replacement)
        renamed = shave_segments(renamed, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({'old': original, 'new': renamed})

    return mapping
def renew_attention_paths(old_list, n_shave_prefix_segments=0, in_mid=False):
    """
    Build a list of {'old': ..., 'new': ...} dicts renaming attention parameter paths.

    Each path in `old_list` goes through a fixed sequence of substring
    substitutions and is then passed through `shave_segments` with
    `n_shave_prefix_segments`. When `in_mid` is true the 'attn' -> 'attentions'
    substitution is skipped; all other substitutions still apply.
    """
    # Applied in order; each substitution sees the output of the previous one.
    common_substitutions = (
        ('.k.', '.key.'),
        ('.v.', '.value.'),
        ('.q.', '.query.'),
        ('proj_out', 'proj_attn'),
        ('norm', 'group_norm'),
    )

    mapping = []
    for original in old_list:
        renamed = original

        # In `model.mid`, the layer is called `attn`.
        if not in_mid:
            renamed = renamed.replace('attn', 'attentions')

        for needle, replacement in common_substitutions:
            renamed = renamed.replace(needle, replacement)

        renamed = shave_segments(renamed, n_shave_prefix_segments=n_shave_prefix_segments)
        mapping.append({'old': original, 'new': renamed})

    return mapping
def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    Copy entries from `old_checkpoint` into `checkpoint` under renamed keys.

    `paths` is a list of dicts with 'old' and 'new' keys. For each entry the
    'new' path has 'down.' / 'up.' rewritten to 'downsample_blocks.' /
    'upsample_blocks.', then every {'old', 'new'} pair in
    `additional_replacements` is applied in order. Values whose final key
    contains 'attentions' are stored after `.squeeze()`.

    If `attention_paths_to_split` is given (a dict mapping an old key to a dict
    with 'query', 'key' and 'value' target keys), the corresponding value is
    first split into three pieces written to those targets; this requires
    `config["num_head_channels"]`. Entries in `paths` whose 'new' path is a key
    of `attention_paths_to_split` are then skipped.

    Mutates `checkpoint` in place and returns None.

    Raises:
        AssertionError: if `paths` is not a list.
        ValueError: if `attention_paths_to_split` is given without `config`.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    split_requested = attention_paths_to_split is not None

    if split_requested:
        if config is None:
            raise ValueError("Please specify the config if setting 'attention_paths_to_split' to 'True'.")

        for fused_key, targets in attention_paths_to_split.items():
            # assumes values are tensor-like (shape/reshape/split/squeeze) — TODO confirm
            fused = old_checkpoint[fused_key]
            channels = fused.shape[0] // 3

            # `(-1)` is the plain integer -1, i.e. flatten to one dimension.
            target_shape = (-1, channels) if len(fused.shape) == 3 else (-1)

            num_heads = fused.shape[0] // config["num_head_channels"] // 3

            fused = fused.reshape((num_heads, 3 * channels // num_heads) + fused.shape[1:])
            query, key, value = fused.split(channels // num_heads, dim=1)

            for role, piece in (('query', query), ('key', key), ('value', value)):
                checkpoint[targets[role]] = piece.reshape(target_shape).squeeze()

    extra_rules = () if additional_replacements is None else additional_replacements

    for entry in paths:
        destination = entry['new']

        # Already written by the split above.
        if split_requested and destination in attention_paths_to_split:
            continue

        destination = destination.replace('down.', 'downsample_blocks.')
        destination = destination.replace('up.', 'upsample_blocks.')

        for rule in extra_rules:
            destination = destination.replace(rule['old'], rule['new'])

        source_value = old_checkpoint[entry['old']]
        checkpoint[destination] = source_value.squeeze() if 'attentions' in destination else source_value
def convert_ddpm_checkpoint(checkpoint, config):
    """
    Takes a state dict and a config, and returns a converted checkpoint.

    `checkpoint` is a mapping from original parameter names ('temb.*',
    'down.*', 'mid.*', 'up.*', 'conv_in.*', 'conv_out.*', 'norm_out.*') to
    values; `config` must provide 'num_res_blocks'. The input mapping is not
    modified; a new dict with renamed keys is returned.

    Raises:
        KeyError: if one of the fixed top-level keys, or a key implied by the
            detected block layout, is missing from `checkpoint`.
    """
    new_checkpoint = {}

    # Fixed, one-to-one renames of the top-level layers.
    new_checkpoint['time_embedding.linear_1.weight'] = checkpoint['temb.dense.0.weight']
    new_checkpoint['time_embedding.linear_1.bias'] = checkpoint['temb.dense.0.bias']
    new_checkpoint['time_embedding.linear_2.weight'] = checkpoint['temb.dense.1.weight']
    new_checkpoint['time_embedding.linear_2.bias'] = checkpoint['temb.dense.1.bias']

    new_checkpoint['conv_norm_out.weight'] = checkpoint['norm_out.weight']
    new_checkpoint['conv_norm_out.bias'] = checkpoint['norm_out.bias']

    new_checkpoint['conv_in.weight'] = checkpoint['conv_in.weight']
    new_checkpoint['conv_in.bias'] = checkpoint['conv_in.bias']
    new_checkpoint['conv_out.weight'] = checkpoint['conv_out.weight']
    new_checkpoint['conv_out.bias'] = checkpoint['conv_out.bias']

    # Count distinct two-segment prefixes (e.g. 'down.0') and group keys per index.
    num_downsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'down' in layer})
    downsample_blocks = {
        layer_id: [key for key in checkpoint if f'down.{layer_id}' in key]
        for layer_id in range(num_downsample_blocks)
    }

    num_upsample_blocks = len({'.'.join(layer.split('.')[:2]) for layer in checkpoint if 'up' in layer})
    upsample_blocks = {
        layer_id: [key for key in checkpoint if f'up.{layer_id}' in key] for layer_id in range(num_upsample_blocks)
    }

    for i in range(num_downsample_blocks):
        if any('downsample' in layer for layer in downsample_blocks[i]):
            # The same downsample conv is written under both the `conv` and `op` names.
            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[
                f'down.{i}.downsample.conv.weight'
            ]
            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.bias'] = checkpoint[
                f'down.{i}.downsample.conv.bias'
            ]
            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[
                f'down.{i}.downsample.conv.weight'
            ]
            new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[
                f'down.{i}.downsample.conv.bias'
            ]

        if any('block' in layer for layer in downsample_blocks[i]):
            num_blocks = len(
                {
                    '.'.join(shave_segments(layer, 2).split('.')[:2])
                    for layer in downsample_blocks[i]
                    if 'block' in layer
                }
            )
            blocks = {
                layer_id: [key for key in downsample_blocks[i] if f'block.{layer_id}' in key]
                for layer_id in range(num_blocks)
            }

            if num_blocks > 0:
                for j in range(config['num_res_blocks']):
                    paths = renew_resnet_paths(blocks[j])
                    assign_to_checkpoint(paths, new_checkpoint, checkpoint)

        if any('attn' in layer for layer in downsample_blocks[i]):
            num_attn = len(
                {
                    '.'.join(shave_segments(layer, 2).split('.')[:2])
                    for layer in downsample_blocks[i]
                    if 'attn' in layer
                }
            )
            # NOTE(review): sized by `num_blocks` (set in the 'block' branch above), not
            # `num_attn` — confirm this is intended before changing it.
            attns = {
                layer_id: [key for key in downsample_blocks[i] if f'attn.{layer_id}' in key]
                for layer_id in range(num_blocks)
            }

            if num_attn > 0:
                for j in range(config['num_res_blocks']):
                    paths = renew_attention_paths(attns[j])
                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)

    mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
    mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
    mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]

    # Mid new 2
    # Mid layers are written under the placeholder prefix 'mid_new_2.', which is
    # renamed back to 'mid' just before returning.
    paths = renew_resnet_paths(mid_block_1_layers)
    assign_to_checkpoint(
        paths,
        new_checkpoint,
        checkpoint,
        additional_replacements=[{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_1', 'new': 'resnets.0'}],
    )

    paths = renew_resnet_paths(mid_block_2_layers)
    assign_to_checkpoint(
        paths,
        new_checkpoint,
        checkpoint,
        additional_replacements=[{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'block_2', 'new': 'resnets.1'}],
    )

    paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
    assign_to_checkpoint(
        paths,
        new_checkpoint,
        checkpoint,
        additional_replacements=[{'old': 'mid.', 'new': 'mid_new_2.'}, {'old': 'attn_1', 'new': 'attentions.0'}],
    )

    for i in range(num_upsample_blocks):
        # Upsample blocks are written in reversed index order.
        block_id = num_upsample_blocks - 1 - i

        if any('upsample' in layer for layer in upsample_blocks[i]):
            new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.weight'] = checkpoint[
                f'up.{i}.upsample.conv.weight'
            ]
            new_checkpoint[f'upsample_blocks.{block_id}.upsamplers.0.conv.bias'] = checkpoint[
                f'up.{i}.upsample.conv.bias'
            ]

        if any('block' in layer for layer in upsample_blocks[i]):
            num_blocks = len(
                {'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'block' in layer}
            )
            blocks = {
                layer_id: [key for key in upsample_blocks[i] if f'block.{layer_id}' in key]
                for layer_id in range(num_blocks)
            }

            if num_blocks > 0:
                # The up path iterates over one more resnet per level than the down path.
                for j in range(config['num_res_blocks'] + 1):
                    replace_indices = {'old': f'upsample_blocks.{i}', 'new': f'upsample_blocks.{block_id}'}
                    paths = renew_resnet_paths(blocks[j])
                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])

        if any('attn' in layer for layer in upsample_blocks[i]):
            num_attn = len(
                {'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'attn' in layer}
            )
            # NOTE(review): sized by `num_blocks`, not `num_attn` — same as the down path.
            attns = {
                layer_id: [key for key in upsample_blocks[i] if f'attn.{layer_id}' in key]
                for layer_id in range(num_blocks)
            }

            if num_attn > 0:
                for j in range(config['num_res_blocks'] + 1):
                    replace_indices = {'old': f'upsample_blocks.{i}', 'new': f'upsample_blocks.{block_id}'}
                    paths = renew_attention_paths(attns[j])
                    assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])

    # Replace the temporary 'mid_new_2' placeholder with the final 'mid' prefix.
    new_checkpoint = {k.replace('mid_new_2', 'mid'): v for k, v in new_checkpoint.items()}
    return new_checkpoint
if __name__ == "__main__":
    # Command-line entry point: load an original DDPM checkpoint and its JSON
    # config, convert the state dict, and save the result to --dump_path.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )

    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the architecture.",
    )

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    args = parser.parse_args()

    # NOTE(review): torch.load unpickles the file — only run this on trusted checkpoints.
    checkpoint = torch.load(args.checkpoint_path)

    with open(args.config_file) as f:
        config = json.load(f)

    # Pass the loaded state dict and parsed config; the converter indexes into
    # both, so handing it the raw CLI path strings would fail.
    converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)
    torch.save(converted_checkpoint, args.dump_path)
scripts/convert_ldm_original_checkpoint_to_diffusers.py
View file @
6cabc599
...
@@ -124,7 +124,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
...
@@ -124,7 +124,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
def
convert_ldm_checkpoint
(
checkpoint
,
config
):
def
convert_ldm_checkpoint
(
checkpoint
,
config
):
"""
"""
Takes a state dict and
the path to
Takes a state dict and
a config, and returns a converted checkpoint.
"""
"""
new_checkpoint
=
{}
new_checkpoint
=
{}
...
@@ -142,15 +142,15 @@ def convert_ldm_checkpoint(checkpoint, config):
...
@@ -142,15 +142,15 @@ def convert_ldm_checkpoint(checkpoint, config):
new_checkpoint
[
'conv_out.bias'
]
=
checkpoint
[
'out.2.bias'
]
new_checkpoint
[
'conv_out.bias'
]
=
checkpoint
[
'out.2.bias'
]
# Retrieves the keys for the input blocks only
# Retrieves the keys for the input blocks only
num_input_blocks
=
len
({
shave_segments
(
layer
,
-
2
)
for
layer
in
checkpoint
if
'input_blocks'
in
layer
})
num_input_blocks
=
len
({
'.'
.
join
(
layer
.
split
(
'.'
)[:
2
]
)
for
layer
in
checkpoint
if
'input_blocks'
in
layer
})
input_blocks
=
{
layer_id
:
[
key
for
key
in
checkpoint
if
f
'input_blocks.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_input_blocks
)}
input_blocks
=
{
layer_id
:
[
key
for
key
in
checkpoint
if
f
'input_blocks.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_input_blocks
)}
# Retrieves the keys for the middle blocks only
# Retrieves the keys for the middle blocks only
num_middle_blocks
=
len
({
shave_segments
(
layer
,
-
2
)
for
layer
in
checkpoint
if
'middle_block'
in
layer
})
num_middle_blocks
=
len
({
'.'
.
join
(
layer
.
split
(
'.'
)[:
2
]
)
for
layer
in
checkpoint
if
'middle_block'
in
layer
})
middle_blocks
=
{
layer_id
:
[
key
for
key
in
checkpoint
if
f
'middle_block.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_middle_blocks
)}
middle_blocks
=
{
layer_id
:
[
key
for
key
in
checkpoint
if
f
'middle_block.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_middle_blocks
)}
# Retrieves the keys for the output blocks only
# Retrieves the keys for the output blocks only
num_output_blocks
=
len
({
shave_segments
(
layer
,
-
2
)
for
layer
in
checkpoint
if
'output_blocks'
in
layer
})
num_output_blocks
=
len
({
'.'
.
join
(
layer
.
split
(
'.'
)[:
2
]
)
for
layer
in
checkpoint
if
'output_blocks'
in
layer
})
output_blocks
=
{
layer_id
:
[
key
for
key
in
checkpoint
if
f
'output_blocks.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_output_blocks
)}
output_blocks
=
{
layer_id
:
[
key
for
key
in
checkpoint
if
f
'output_blocks.
{
layer_id
}
'
in
key
]
for
layer_id
in
range
(
num_output_blocks
)}
for
i
in
range
(
1
,
num_input_blocks
):
for
i
in
range
(
1
,
num_input_blocks
):
...
...
tests/test_modeling_utils.py
View file @
6cabc599
...
@@ -929,7 +929,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -929,7 +929,7 @@ class PipelineTesterMixin(unittest.TestCase):
@
slow
@
slow
def
test_from_pretrained_hub
(
self
):
def
test_from_pretrained_hub
(
self
):
model_path
=
"
fusing
/ddpm-cifar10"
model_path
=
"
google
/ddpm-cifar10"
ddpm
=
DDPMPipeline
.
from_pretrained
(
model_path
)
ddpm
=
DDPMPipeline
.
from_pretrained
(
model_path
)
ddpm_from_hub
=
DiffusionPipeline
.
from_pretrained
(
model_path
)
ddpm_from_hub
=
DiffusionPipeline
.
from_pretrained
(
model_path
)
...
@@ -947,9 +947,9 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -947,9 +947,9 @@ class PipelineTesterMixin(unittest.TestCase):
@
slow
@
slow
def
test_ddpm_cifar10
(
self
):
def
test_ddpm_cifar10
(
self
):
model_id
=
"
fusing
/ddpm-cifar10"
model_id
=
"
google
/ddpm-cifar10"
unet
=
UNetUnconditionalModel
.
from_pretrained
(
model_id
,
ddpm
=
True
)
unet
=
UNetUnconditionalModel
.
from_pretrained
(
model_id
)
scheduler
=
DDPMScheduler
.
from_config
(
model_id
)
scheduler
=
DDPMScheduler
.
from_config
(
model_id
)
scheduler
=
scheduler
.
set_format
(
"pt"
)
scheduler
=
scheduler
.
set_format
(
"pt"
)
...
@@ -968,9 +968,9 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -968,9 +968,9 @@ class PipelineTesterMixin(unittest.TestCase):
@
slow
@
slow
def
test_ddim_lsun
(
self
):
def
test_ddim_lsun
(
self
):
model_id
=
"
fusing
/ddpm-lsun-bedroom-ema"
model_id
=
"
google
/ddpm-lsun-bedroom-ema"
unet
=
UNetUnconditionalModel
.
from_pretrained
(
model_id
,
ddpm
=
True
)
unet
=
UNetUnconditionalModel
.
from_pretrained
(
model_id
)
scheduler
=
DDIMScheduler
.
from_config
(
model_id
)
scheduler
=
DDIMScheduler
.
from_config
(
model_id
)
ddpm
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
ddpm
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
...
@@ -988,9 +988,9 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -988,9 +988,9 @@ class PipelineTesterMixin(unittest.TestCase):
@
slow
@
slow
def
test_ddim_cifar10
(
self
):
def
test_ddim_cifar10
(
self
):
model_id
=
"
fusing
/ddpm-cifar10"
model_id
=
"
google
/ddpm-cifar10"
unet
=
UNetUnconditionalModel
.
from_pretrained
(
model_id
,
ddpm
=
True
)
unet
=
UNetUnconditionalModel
.
from_pretrained
(
model_id
)
scheduler
=
DDIMScheduler
(
tensor_format
=
"pt"
)
scheduler
=
DDIMScheduler
(
tensor_format
=
"pt"
)
ddim
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
ddim
=
DDIMPipeline
(
unet
=
unet
,
scheduler
=
scheduler
)
...
@@ -1008,7 +1008,7 @@ class PipelineTesterMixin(unittest.TestCase):
...
@@ -1008,7 +1008,7 @@ class PipelineTesterMixin(unittest.TestCase):
@
slow
@
slow
def
test_pndm_cifar10
(
self
):
def
test_pndm_cifar10
(
self
):
model_id
=
"
fusing
/ddpm-cifar10"
model_id
=
"
google
/ddpm-cifar10"
unet
=
UNetUnconditionalModel
.
from_pretrained
(
model_id
,
ddpm
=
True
)
unet
=
UNetUnconditionalModel
.
from_pretrained
(
model_id
,
ddpm
=
True
)
scheduler
=
PNDMScheduler
(
tensor_format
=
"pt"
)
scheduler
=
PNDMScheduler
(
tensor_format
=
"pt"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment