Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
462956be
Unverified
Commit
462956be
authored
Jun 05, 2023
by
Will Berman
Committed by
GitHub
Jun 05, 2023
Browse files
small tweaks for parsing thibaudz controlnet checkpoints (#3657)
parent
59900147
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
87 additions
and
30 deletions
+87
-30
scripts/convert_original_controlnet_to_diffusers.py
scripts/convert_original_controlnet_to_diffusers.py
+18
-0
src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
...diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+69
-30
No files found.
scripts/convert_original_controlnet_to_diffusers.py
View file @
462956be
...
@@ -75,6 +75,22 @@ if __name__ == "__main__":
...
@@ -75,6 +75,22 @@ if __name__ == "__main__":
)
)
parser
.
add_argument
(
"--dump_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the output model."
)
parser
.
add_argument
(
"--dump_path"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"Path to the output model."
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
help
=
"Device to use (e.g. cpu, cuda:0, cuda:1, etc.)"
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
help
=
"Device to use (e.g. cpu, cuda:0, cuda:1, etc.)"
)
# small workaround to get argparser to parse a boolean input as either true _or_ false
def
parse_bool
(
string
):
if
string
==
"True"
:
return
True
elif
string
==
"False"
:
return
False
else
:
raise
ValueError
(
f
"could not parse string as bool
{
string
}
"
)
parser
.
add_argument
(
"--use_linear_projection"
,
help
=
"Override for use linear projection"
,
required
=
False
,
type
=
parse_bool
)
parser
.
add_argument
(
"--cross_attention_dim"
,
help
=
"Override for cross attention_dim"
,
required
=
False
,
type
=
int
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
controlnet
=
download_controlnet_from_original_ckpt
(
controlnet
=
download_controlnet_from_original_ckpt
(
...
@@ -86,6 +102,8 @@ if __name__ == "__main__":
...
@@ -86,6 +102,8 @@ if __name__ == "__main__":
upcast_attention
=
args
.
upcast_attention
,
upcast_attention
=
args
.
upcast_attention
,
from_safetensors
=
args
.
from_safetensors
,
from_safetensors
=
args
.
from_safetensors
,
device
=
args
.
device
,
device
=
args
.
device
,
use_linear_projection
=
args
.
use_linear_projection
,
cross_attention_dim
=
args
.
cross_attention_dim
,
)
)
controlnet
.
save_pretrained
(
args
.
dump_path
,
safe_serialization
=
args
.
to_safetensors
)
controlnet
.
save_pretrained
(
args
.
dump_path
,
safe_serialization
=
args
.
to_safetensors
)
src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
View file @
462956be
...
@@ -339,11 +339,16 @@ def create_ldm_bert_config(original_config):
...
@@ -339,11 +339,16 @@ def create_ldm_bert_config(original_config):
return
config
return
config
def
convert_ldm_unet_checkpoint
(
checkpoint
,
config
,
path
=
None
,
extract_ema
=
False
,
controlnet
=
False
):
def
convert_ldm_unet_checkpoint
(
checkpoint
,
config
,
path
=
None
,
extract_ema
=
False
,
controlnet
=
False
,
skip_extract_state_dict
=
False
):
"""
"""
Takes a state dict and a config, and returns a converted checkpoint.
Takes a state dict and a config, and returns a converted checkpoint.
"""
"""
if
skip_extract_state_dict
:
unet_state_dict
=
checkpoint
else
:
# extract state_dict for UNet
# extract state_dict for UNet
unet_state_dict
=
{}
unet_state_dict
=
{}
keys
=
list
(
checkpoint
.
keys
())
keys
=
list
(
checkpoint
.
keys
())
...
@@ -956,17 +961,42 @@ def stable_unclip_image_noising_components(
...
@@ -956,17 +961,42 @@ def stable_unclip_image_noising_components(
def
convert_controlnet_checkpoint
(
def
convert_controlnet_checkpoint
(
checkpoint
,
original_config
,
checkpoint_path
,
image_size
,
upcast_attention
,
extract_ema
checkpoint
,
original_config
,
checkpoint_path
,
image_size
,
upcast_attention
,
extract_ema
,
use_linear_projection
=
None
,
cross_attention_dim
=
None
,
):
):
ctrlnet_config
=
create_unet_diffusers_config
(
original_config
,
image_size
=
image_size
,
controlnet
=
True
)
ctrlnet_config
=
create_unet_diffusers_config
(
original_config
,
image_size
=
image_size
,
controlnet
=
True
)
ctrlnet_config
[
"upcast_attention"
]
=
upcast_attention
ctrlnet_config
[
"upcast_attention"
]
=
upcast_attention
ctrlnet_config
.
pop
(
"sample_size"
)
ctrlnet_config
.
pop
(
"sample_size"
)
if
use_linear_projection
is
not
None
:
ctrlnet_config
[
"use_linear_projection"
]
=
use_linear_projection
if
cross_attention_dim
is
not
None
:
ctrlnet_config
[
"cross_attention_dim"
]
=
cross_attention_dim
controlnet_model
=
ControlNetModel
(
**
ctrlnet_config
)
controlnet_model
=
ControlNetModel
(
**
ctrlnet_config
)
# Some controlnet ckpt files are distributed independently from the rest of the
# model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
if
"time_embed.0.weight"
in
checkpoint
:
skip_extract_state_dict
=
True
else
:
skip_extract_state_dict
=
False
converted_ctrl_checkpoint
=
convert_ldm_unet_checkpoint
(
converted_ctrl_checkpoint
=
convert_ldm_unet_checkpoint
(
checkpoint
,
ctrlnet_config
,
path
=
checkpoint_path
,
extract_ema
=
extract_ema
,
controlnet
=
True
checkpoint
,
ctrlnet_config
,
path
=
checkpoint_path
,
extract_ema
=
extract_ema
,
controlnet
=
True
,
skip_extract_state_dict
=
skip_extract_state_dict
,
)
)
controlnet_model
.
load_state_dict
(
converted_ctrl_checkpoint
)
controlnet_model
.
load_state_dict
(
converted_ctrl_checkpoint
)
...
@@ -1344,6 +1374,8 @@ def download_controlnet_from_original_ckpt(
...
@@ -1344,6 +1374,8 @@ def download_controlnet_from_original_ckpt(
upcast_attention
:
Optional
[
bool
]
=
None
,
upcast_attention
:
Optional
[
bool
]
=
None
,
device
:
str
=
None
,
device
:
str
=
None
,
from_safetensors
:
bool
=
False
,
from_safetensors
:
bool
=
False
,
use_linear_projection
:
Optional
[
bool
]
=
None
,
cross_attention_dim
:
Optional
[
bool
]
=
None
,
)
->
DiffusionPipeline
:
)
->
DiffusionPipeline
:
if
not
is_omegaconf_available
():
if
not
is_omegaconf_available
():
raise
ValueError
(
BACKENDS_MAPPING
[
"omegaconf"
][
1
])
raise
ValueError
(
BACKENDS_MAPPING
[
"omegaconf"
][
1
])
...
@@ -1381,7 +1413,14 @@ def download_controlnet_from_original_ckpt(
...
@@ -1381,7 +1413,14 @@ def download_controlnet_from_original_ckpt(
raise
ValueError
(
"`control_stage_config` not present in original config"
)
raise
ValueError
(
"`control_stage_config` not present in original config"
)
controlnet_model
=
convert_controlnet_checkpoint
(
controlnet_model
=
convert_controlnet_checkpoint
(
checkpoint
,
original_config
,
checkpoint_path
,
image_size
,
upcast_attention
,
extract_ema
checkpoint
,
original_config
,
checkpoint_path
,
image_size
,
upcast_attention
,
extract_ema
,
use_linear_projection
=
use_linear_projection
,
cross_attention_dim
=
cross_attention_dim
,
)
)
return
controlnet_model
return
controlnet_model
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment