renzhc/diffusers_dcu, commit 462956be (unverified)
Authored Jun 05, 2023 by Will Berman; committed by GitHub, Jun 05, 2023
Parent: 59900147

small tweaks for parsing thibaudz controlnet checkpoints (#3657)
Changes: 2 changed files, with 87 additions and 30 deletions (+87 -30)

  scripts/convert_original_controlnet_to_diffusers.py           +18 -0
  src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +69 -30
scripts/convert_original_controlnet_to_diffusers.py

@@ -75,6 +75,22 @@ if __name__ == "__main__":
     )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
     parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
+
+    # small workaround to get argparser to parse a boolean input as either true _or_ false
+    def parse_bool(string):
+        if string == "True":
+            return True
+        elif string == "False":
+            return False
+        else:
+            raise ValueError(f"could not parse string as bool {string}")
+
+    parser.add_argument(
+        "--use_linear_projection", help="Override for use linear projection", required=False, type=parse_bool
+    )
+
+    parser.add_argument("--cross_attention_dim", help="Override for cross attention_dim", required=False, type=int)
+
     args = parser.parse_args()

     controlnet = download_controlnet_from_original_ckpt(

@@ -86,6 +102,8 @@ if __name__ == "__main__":
         upcast_attention=args.upcast_attention,
         from_safetensors=args.from_safetensors,
         device=args.device,
+        use_linear_projection=args.use_linear_projection,
+        cross_attention_dim=args.cross_attention_dim,
     )

     controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
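Note on the `parse_bool` workaround above: argparse's built-in `type=bool` just calls `bool()` on the raw argument string, and any non-empty string, including "False", is truthy, so a boolean flag declared that way could never be switched off from the command line. A minimal sketch of the pitfall and the fix (the `--naive` and `--fixed` flag names are illustrative, not part of the script):

    import argparse

    def parse_bool(string):
        if string == "True":
            return True
        elif string == "False":
            return False
        else:
            raise ValueError(f"could not parse string as bool {string}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--naive", type=bool)        # bool("False") is True
    parser.add_argument("--fixed", type=parse_bool)  # round-trips "True"/"False"

    args = parser.parse_args(["--naive", "False", "--fixed", "False"])
    assert args.naive is True   # pitfall: every non-empty string is truthy
    assert args.fixed is False  # workaround: the literal text is parsed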
src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

@@ -339,41 +339,46 @@ def create_ldm_bert_config(original_config):
     return config


-def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
+def convert_ldm_unet_checkpoint(
+    checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
+):
     """
     Takes a state dict and a config, and returns a converted checkpoint.
     """

-    # extract state_dict for UNet
-    unet_state_dict = {}
-    keys = list(checkpoint.keys())
-
-    if controlnet:
-        unet_key = "control_model."
-    else:
-        unet_key = "model.diffusion_model."
-
-    # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
-    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
-        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
-        print(
-            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
-            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
-        )
-        for key in keys:
-            if key.startswith("model.diffusion_model"):
-                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
-                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
-    else:
-        if sum(k.startswith("model_ema") for k in keys) > 100:
-            print(
-                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
-                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
-            )
-
-        for key in keys:
-            if key.startswith(unet_key):
-                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+    if skip_extract_state_dict:
+        unet_state_dict = checkpoint
+    else:
+        # extract state_dict for UNet
+        unet_state_dict = {}
+        keys = list(checkpoint.keys())
+
+        if controlnet:
+            unet_key = "control_model."
+        else:
+            unet_key = "model.diffusion_model."
+
+        # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
+        if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
+            print(f"Checkpoint {path} has both EMA and non-EMA weights.")
+            print(
+                "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+                " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+            )
+            for key in keys:
+                if key.startswith("model.diffusion_model"):
+                    flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+        else:
+            if sum(k.startswith("model_ema") for k in keys) > 100:
+                print(
+                    "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+                    " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+                )
+
+            for key in keys:
+                if key.startswith(unet_key):
+                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

     new_checkpoint = {}
@@ -956,17 +961,42 @@ def stable_unclip_image_noising_components(
 def convert_controlnet_checkpoint(
-    checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
+    checkpoint,
+    original_config,
+    checkpoint_path,
+    image_size,
+    upcast_attention,
+    extract_ema,
+    use_linear_projection=None,
+    cross_attention_dim=None,
 ):
     ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
     ctrlnet_config["upcast_attention"] = upcast_attention

     ctrlnet_config.pop("sample_size")

+    if use_linear_projection is not None:
+        ctrlnet_config["use_linear_projection"] = use_linear_projection
+
+    if cross_attention_dim is not None:
+        ctrlnet_config["cross_attention_dim"] = cross_attention_dim
+
     controlnet_model = ControlNetModel(**ctrlnet_config)

+    # Some controlnet ckpt files are distributed independently from the rest of the
+    # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
+    if "time_embed.0.weight" in checkpoint:
+        skip_extract_state_dict = True
+    else:
+        skip_extract_state_dict = False
+
     converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
-        checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
+        checkpoint,
+        ctrlnet_config,
+        path=checkpoint_path,
+        extract_ema=extract_ema,
+        controlnet=True,
+        skip_extract_state_dict=skip_extract_state_dict,
     )

     controlnet_model.load_state_dict(converted_ctrl_checkpoint)
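The `"time_embed.0.weight"` membership test above is the heart of this commit: a ControlNet bundled inside a full Stable Diffusion checkpoint prefixes every parameter with `control_model.`, while independently distributed checkpoints such as thibaud's already use the bare key layout that prefix-stripping would produce, so extraction must be skipped for them. A toy illustration with stand-in values (the keys are real parameter names, the tensor values are placeholders):

    standalone = {"time_embed.0.weight": 0, "input_blocks.0.0.weight": 1}
    bundled = {f"control_model.{k}": v for k, v in standalone.items()}

    # Standalone files are already in the extracted layout ...
    assert "time_embed.0.weight" in standalone   # -> skip_extract_state_dict = True
    # ... bundled ones still carry the "control_model." prefix.
    assert "time_embed.0.weight" not in bundled  # -> prefix-stripping path runs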
@@ -1344,6 +1374,8 @@ def download_controlnet_from_original_ckpt(
     upcast_attention: Optional[bool] = None,
     device: str = None,
     from_safetensors: bool = False,
+    use_linear_projection: Optional[bool] = None,
+    cross_attention_dim: Optional[bool] = None,
 ) -> DiffusionPipeline:
     if not is_omegaconf_available():
         raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
@@ -1381,7 +1413,14 @@ def download_controlnet_from_original_ckpt(
         raise ValueError("`control_stage_config` not present in original config")

     controlnet_model = convert_controlnet_checkpoint(
-        checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
+        checkpoint,
+        original_config,
+        checkpoint_path,
+        image_size,
+        upcast_attention,
+        extract_ema,
+        use_linear_projection=use_linear_projection,
+        cross_attention_dim=cross_attention_dim,
     )

     return controlnet_model
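End to end, the new keyword arguments let a standalone SD 2.1 ControlNet be converted in one call. A hedged sketch, assuming the checkpoint and its original LDM config have already been downloaded locally (the file names are placeholders); SD 2.x models need `use_linear_projection=True` and a 1024-dimensional cross-attention width, which is exactly what the new overrides expose:

    from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
        download_controlnet_from_original_ckpt,
    )

    # Placeholder paths, e.g. a checkpoint from
    # https://huggingface.co/thibaud/controlnet-sd21/ plus the matching config.
    controlnet = download_controlnet_from_original_ckpt(
        checkpoint_path="./control_v2-1_openpose.ckpt",
        original_config_file="./cldm_v21.yaml",
        use_linear_projection=True,  # SD 2.x attention uses linear projections
        cross_attention_dim=1024,    # SD 2.x text-encoder hidden size
    )
    controlnet.save_pretrained("./controlnet-sd21-openpose")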