Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
462956be
Unverified
Commit
462956be
authored
Jun 05, 2023
by
Will Berman
Committed by
GitHub
Jun 05, 2023
Browse files
small tweaks for parsing thibaudz controlnet checkpoints (#3657)
parent
59900147
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
87 additions
and
30 deletions
+87
-30
scripts/convert_original_controlnet_to_diffusers.py
scripts/convert_original_controlnet_to_diffusers.py
+18
-0
src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
...diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+69
-30
No files found.
scripts/convert_original_controlnet_to_diffusers.py
View file @
462956be
...
...
@@ -75,6 +75,22 @@ if __name__ == "__main__":
)
parser
.
add_argument
(
"--dump_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the output model."
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
help
=
"Device to use (e.g. cpu, cuda:0, cuda:1, etc.)"
)
# small workaround to get argparser to parse a boolean input as either true _or_ false
def parse_bool(string):
    """Map the exact literals "True"/"False" to their boolean values.

    argparse's built-in ``type=bool`` treats any non-empty string as truthy,
    so this explicit converter is used instead.

    Raises:
        ValueError: if ``string`` is anything other than "True" or "False".
    """
    literal_to_bool = {"True": True, "False": False}
    if string in literal_to_bool:
        return literal_to_bool[string]
    raise ValueError(f"could not parse string as bool {string}")
parser
.
add_argument
(
"--use_linear_projection"
,
help
=
"Override for use linear projection"
,
required
=
False
,
type
=
parse_bool
)
parser
.
add_argument
(
"--cross_attention_dim"
,
help
=
"Override for cross attention_dim"
,
required
=
False
,
type
=
int
)
args
=
parser
.
parse_args
()
controlnet
=
download_controlnet_from_original_ckpt
(
...
...
@@ -86,6 +102,8 @@ if __name__ == "__main__":
upcast_attention
=
args
.
upcast_attention
,
from_safetensors
=
args
.
from_safetensors
,
device
=
args
.
device
,
use_linear_projection
=
args
.
use_linear_projection
,
cross_attention_dim
=
args
.
cross_attention_dim
,
)
controlnet
.
save_pretrained
(
args
.
dump_path
,
safe_serialization
=
args
.
to_safetensors
)
src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
View file @
462956be
...
...
@@ -339,11 +339,16 @@ def create_ldm_bert_config(original_config):
return
config
def
convert_ldm_unet_checkpoint
(
checkpoint
,
config
,
path
=
None
,
extract_ema
=
False
,
controlnet
=
False
):
def
convert_ldm_unet_checkpoint
(
checkpoint
,
config
,
path
=
None
,
extract_ema
=
False
,
controlnet
=
False
,
skip_extract_state_dict
=
False
):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
if
skip_extract_state_dict
:
unet_state_dict
=
checkpoint
else
:
# extract state_dict for UNet
unet_state_dict
=
{}
keys
=
list
(
checkpoint
.
keys
())
...
...
@@ -956,17 +961,42 @@ def stable_unclip_image_noising_components(
def
convert_controlnet_checkpoint
(
checkpoint
,
original_config
,
checkpoint_path
,
image_size
,
upcast_attention
,
extract_ema
checkpoint
,
original_config
,
checkpoint_path
,
image_size
,
upcast_attention
,
extract_ema
,
use_linear_projection
=
None
,
cross_attention_dim
=
None
,
):
ctrlnet_config
=
create_unet_diffusers_config
(
original_config
,
image_size
=
image_size
,
controlnet
=
True
)
ctrlnet_config
[
"upcast_attention"
]
=
upcast_attention
ctrlnet_config
.
pop
(
"sample_size"
)
if
use_linear_projection
is
not
None
:
ctrlnet_config
[
"use_linear_projection"
]
=
use_linear_projection
if
cross_attention_dim
is
not
None
:
ctrlnet_config
[
"cross_attention_dim"
]
=
cross_attention_dim
controlnet_model
=
ControlNetModel
(
**
ctrlnet_config
)
# Some controlnet ckpt files are distributed independently from the rest of the
# model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
if
"time_embed.0.weight"
in
checkpoint
:
skip_extract_state_dict
=
True
else
:
skip_extract_state_dict
=
False
converted_ctrl_checkpoint
=
convert_ldm_unet_checkpoint
(
checkpoint
,
ctrlnet_config
,
path
=
checkpoint_path
,
extract_ema
=
extract_ema
,
controlnet
=
True
checkpoint
,
ctrlnet_config
,
path
=
checkpoint_path
,
extract_ema
=
extract_ema
,
controlnet
=
True
,
skip_extract_state_dict
=
skip_extract_state_dict
,
)
controlnet_model
.
load_state_dict
(
converted_ctrl_checkpoint
)
...
...
@@ -1344,6 +1374,8 @@ def download_controlnet_from_original_ckpt(
upcast_attention
:
Optional
[
bool
]
=
None
,
device
:
str
=
None
,
from_safetensors
:
bool
=
False
,
use_linear_projection
:
Optional
[
bool
]
=
None
,
cross_attention_dim
:
Optional
[
bool
]
=
None
,
)
->
DiffusionPipeline
:
if
not
is_omegaconf_available
():
raise
ValueError
(
BACKENDS_MAPPING
[
"omegaconf"
][
1
])
...
...
@@ -1381,7 +1413,14 @@ def download_controlnet_from_original_ckpt(
raise
ValueError
(
"`control_stage_config` not present in original config"
)
controlnet_model
=
convert_controlnet_checkpoint
(
checkpoint
,
original_config
,
checkpoint_path
,
image_size
,
upcast_attention
,
extract_ema
checkpoint
,
original_config
,
checkpoint_path
,
image_size
,
upcast_attention
,
extract_ema
,
use_linear_projection
=
use_linear_projection
,
cross_attention_dim
=
cross_attention_dim
,
)
return
controlnet_model
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment