Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
255ac592
Unverified
Commit
255ac592
authored
Aug 23, 2024
by
Dhruv Nair
Committed by
GitHub
Aug 23, 2024
Browse files
[Single File] Support loading Comfy UI Flux checkpoints (#9243)
update
parent
2d9ccf39
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
6 deletions
+18
-6
src/diffusers/loaders/single_file_utils.py
src/diffusers/loaders/single_file_utils.py
+18
-6
No files found.
src/diffusers/loaders/single_file_utils.py
View file @
255ac592
...
@@ -79,7 +79,10 @@ CHECKPOINT_KEY_NAMES = {
...
@@ -79,7 +79,10 @@ CHECKPOINT_KEY_NAMES = {
"animatediff_sdxl_beta"
:
"up_blocks.2.motion_modules.0.temporal_transformer.norm.weight"
,
"animatediff_sdxl_beta"
:
"up_blocks.2.motion_modules.0.temporal_transformer.norm.weight"
,
"animatediff_scribble"
:
"controlnet_cond_embedding.conv_in.weight"
,
"animatediff_scribble"
:
"controlnet_cond_embedding.conv_in.weight"
,
"animatediff_rgb"
:
"controlnet_cond_embedding.weight"
,
"animatediff_rgb"
:
"controlnet_cond_embedding.weight"
,
"flux"
:
"double_blocks.0.img_attn.norm.key_norm.scale"
,
"flux"
:
[
"double_blocks.0.img_attn.norm.key_norm.scale"
,
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale"
,
],
}
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS
=
{
DIFFUSERS_DEFAULT_PIPELINE_PATHS
=
{
...
@@ -258,7 +261,7 @@ SCHEDULER_DEFAULT_CONFIG = {
...
@@ -258,7 +261,7 @@ SCHEDULER_DEFAULT_CONFIG = {
"timestep_spacing"
:
"leading"
,
"timestep_spacing"
:
"leading"
,
}
}
LDM_VAE_KEY
=
"first_stage_model."
LDM_VAE_KEY
S
=
[
"first_stage_model."
,
"vae."
]
LDM_VAE_DEFAULT_SCALING_FACTOR
=
0.18215
LDM_VAE_DEFAULT_SCALING_FACTOR
=
0.18215
PLAYGROUND_VAE_SCALING_FACTOR
=
0.5
PLAYGROUND_VAE_SCALING_FACTOR
=
0.5
LDM_UNET_KEY
=
"model.diffusion_model."
LDM_UNET_KEY
=
"model.diffusion_model."
...
@@ -267,7 +270,6 @@ LDM_CLIP_PREFIX_TO_REMOVE = [
...
@@ -267,7 +270,6 @@ LDM_CLIP_PREFIX_TO_REMOVE = [
"cond_stage_model.transformer."
,
"cond_stage_model.transformer."
,
"conditioner.embedders.0.transformer."
,
"conditioner.embedders.0.transformer."
,
]
]
OPEN_CLIP_PREFIX
=
"conditioner.embedders.0.model."
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
=
1024
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
=
1024
SCHEDULER_LEGACY_KWARGS
=
[
"prediction_type"
,
"scheduler_type"
]
SCHEDULER_LEGACY_KWARGS
=
[
"prediction_type"
,
"scheduler_type"
]
...
@@ -523,8 +525,10 @@ def infer_diffusers_model_type(checkpoint):
...
@@ -523,8 +525,10 @@ def infer_diffusers_model_type(checkpoint):
else
:
else
:
model_type
=
"animatediff_v3"
model_type
=
"animatediff_v3"
elif
CHECKPOINT_KEY_NAMES
[
"flux"
]
in
checkpoint
:
elif
any
(
key
in
checkpoint
for
key
in
CHECKPOINT_KEY_NAMES
[
"flux"
]):
if
"guidance_in.in_layer.bias"
in
checkpoint
:
if
any
(
g
in
checkpoint
for
g
in
[
"guidance_in.in_layer.bias"
,
"model.diffusion_model.guidance_in.in_layer.bias"
]
):
model_type
=
"flux-dev"
model_type
=
"flux-dev"
else
:
else
:
model_type
=
"flux-schnell"
model_type
=
"flux-schnell"
...
@@ -1183,7 +1187,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
...
@@ -1183,7 +1187,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
# remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys
# remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys
vae_state_dict
=
{}
vae_state_dict
=
{}
keys
=
list
(
checkpoint
.
keys
())
keys
=
list
(
checkpoint
.
keys
())
vae_key
=
LDM_VAE_KEY
if
any
(
k
.
startswith
(
LDM_VAE_KEY
)
for
k
in
keys
)
else
""
vae_key
=
""
for
ldm_vae_key
in
LDM_VAE_KEYS
:
if
any
(
k
.
startswith
(
ldm_vae_key
)
for
k
in
keys
):
vae_key
=
ldm_vae_key
for
key
in
keys
:
for
key
in
keys
:
if
key
.
startswith
(
vae_key
):
if
key
.
startswith
(
vae_key
):
vae_state_dict
[
key
.
replace
(
vae_key
,
""
)]
=
checkpoint
.
get
(
key
)
vae_state_dict
[
key
.
replace
(
vae_key
,
""
)]
=
checkpoint
.
get
(
key
)
...
@@ -1896,6 +1904,10 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
...
@@ -1896,6 +1904,10 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
def
convert_flux_transformer_checkpoint_to_diffusers
(
checkpoint
,
**
kwargs
):
def
convert_flux_transformer_checkpoint_to_diffusers
(
checkpoint
,
**
kwargs
):
converted_state_dict
=
{}
converted_state_dict
=
{}
keys
=
list
(
checkpoint
.
keys
())
for
k
in
keys
:
if
"model.diffusion_model."
in
k
:
checkpoint
[
k
.
replace
(
"model.diffusion_model."
,
""
)]
=
checkpoint
.
pop
(
k
)
num_layers
=
list
(
set
(
int
(
k
.
split
(
"."
,
2
)[
1
])
for
k
in
checkpoint
if
"double_blocks."
in
k
))[
-
1
]
+
1
# noqa: C401
num_layers
=
list
(
set
(
int
(
k
.
split
(
"."
,
2
)[
1
])
for
k
in
checkpoint
if
"double_blocks."
in
k
))[
-
1
]
+
1
# noqa: C401
num_single_layers
=
list
(
set
(
int
(
k
.
split
(
"."
,
2
)[
1
])
for
k
in
checkpoint
if
"single_blocks."
in
k
))[
-
1
]
+
1
# noqa: C401
num_single_layers
=
list
(
set
(
int
(
k
.
split
(
"."
,
2
)[
1
])
for
k
in
checkpoint
if
"single_blocks."
in
k
))[
-
1
]
+
1
# noqa: C401
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment