chenpangpang / ComfyUI
"...git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "3e58e5170068b1dc0bdbfd6baa9e69dec96a9f44"
Commit d91f45ef authored Feb 19, 2024 by comfyanonymous

Some cleanups to how the text encoders are loaded.

Parent: dbe0979b
Showing 3 changed files with 32 additions and 28 deletions.
comfy/sd.py                      +13 -10
comfy/supported_models.py        +15 -16
comfy/supported_models_base.py   +4 -2
comfy/sd.py (view file @ d91f45ef)
@@ -138,8 +138,11 @@ class CLIP:
         tokens = self.tokenize(text)
         return self.encode_from_tokens(tokens)
 
-    def load_sd(self, sd):
-        return self.cond_stage_model.load_sd(sd)
+    def load_sd(self, sd, full_model=False):
+        if full_model:
+            return self.cond_stage_model.load_state_dict(sd, strict=False)
+        else:
+            return self.cond_stage_model.load_sd(sd)
 
     def get_sd(self):
         return self.cond_stage_model.state_dict()
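When full_model is set, the new path defers to PyTorch's Module.load_state_dict with strict=False, which returns the lists of missing and unexpected keys instead of raising on a mismatch. A minimal illustration of how a caller can use the new signature (clip is assumed to be a comfy.sd.CLIP instance and clip_sd a dict of text-encoder tensors; this sketch is not part of the commit):

# Hypothetical caller: push text-encoder weights into an existing CLIP wrapper
# and report anything that did not line up.
missing, unexpected = clip.load_sd(clip_sd, full_model=True)
if len(missing) > 0:
    print("clip missing:", missing)        # parameters the model expected but the dict lacked
if len(unexpected) > 0:
    print("clip unexpected:", unexpected)  # dict entries with no matching parameter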
@@ -494,9 +497,6 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
     load_device = model_management.get_torch_device()
 
-    class WeightsLoader(torch.nn.Module):
-        pass
-
     model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.")
     unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
@@ -521,14 +521,17 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         vae = VAE(sd=vae_sd)
 
     if output_clip:
-        w = WeightsLoader()
         clip_target = model_config.clip_target()
         if clip_target is not None:
-            sd = model_config.process_clip_state_dict(sd)
-            if any(k.startswith('cond_stage_model.') for k in sd):
+            clip_sd = model_config.process_clip_state_dict(sd)
+            if len(clip_sd) > 0:
                 clip = CLIP(clip_target, embedding_directory=embedding_directory)
-                w.cond_stage_model = clip.cond_stage_model
-                load_model_weights(w, sd)
+                m, u = clip.load_sd(clip_sd, full_model=True)
+                if len(m) > 0:
+                    print("clip missing:", m)
+                if len(u) > 0:
+                    print("clip unexpected:", u)
             else:
                 print("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
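The loading path now builds a separate clip_sd containing only text-encoder tensors, because the supported-model classes below call utils.state_dict_prefix_replace with filter_keys=True. The following is a rough behavioral sketch of that helper, written only to illustrate the rename-and-filter semantics this commit relies on; the real implementation lives in comfy/utils.py and may differ in detail:

def state_dict_prefix_replace_sketch(state_dict, replace_prefix, filter_keys=False):
    # With filter_keys=True, start from an empty dict so only keys matching one
    # of the prefixes survive; otherwise keep non-matching keys untouched.
    out = {} if filter_keys else dict(state_dict)
    for old_prefix, new_prefix in replace_prefix.items():
        for key in list(state_dict.keys()):
            if key.startswith(old_prefix):
                out.pop(key, None)  # drop the original spelling if it was copied over
                out[new_prefix + key[len(old_prefix):]] = state_dict[key]
    return out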
comfy/supported_models.py (view file @ d91f45ef)
@@ -40,8 +40,8 @@ class SD15(supported_models_base.BASE):
             state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
 
         replace_prefix = {}
-        replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l."
-        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
+        replace_prefix["cond_stage_model."] = "clip_l."
+        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
         return state_dict
 
     def process_clip_state_dict_for_saving(self, state_dict):
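For SD1.x the practical effect is that text-encoder keys now land under a bare clip_l. prefix and everything else is filtered out before CLIP.load_sd sees the dict. An illustrative run through the sketch helper defined above (tensor values replaced by placeholder integers, key names are examples):

sd = {
    "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": 1,
    "model.diffusion_model.input_blocks.0.0.weight": 2,  # unrelated UNet key
}
clip_sd = state_dict_prefix_replace_sketch(sd, {"cond_stage_model.": "clip_l."}, filter_keys=True)
# clip_sd == {"clip_l.transformer.text_model.encoder.layers.0.mlp.fc1.weight": 1}
# The UNet key is dropped: only text-encoder weights reach CLIP.load_sd.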
@@ -72,10 +72,10 @@ class SD20(supported_models_base.BASE):
     def process_clip_state_dict(self, state_dict):
         replace_prefix = {}
-        replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format
-        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
-        state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
+        replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
+        replace_prefix["cond_stage_model.model."] = "clip_h."
+        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
+        state_dict = utils.transformers_convert(state_dict, "clip_h.", "clip_h.transformer.text_model.", 24)
         return state_dict
 
     def process_clip_state_dict_for_saving(self, state_dict):
@@ -131,11 +131,10 @@ class SDXLRefiner(supported_models_base.BASE):
     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}
         replace_prefix = {}
+        replace_prefix["conditioner.embedders.0.model."] = "clip_g."
+        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
 
-        state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
-        keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
-        keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
+        state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32)
 
         state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
         return state_dict
@@ -179,13 +178,13 @@ class SDXL(supported_models_base.BASE):
         keys_to_replace = {}
         replace_prefix = {}
 
-        replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
-        state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
-        keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
-        keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection"
-        keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
+        replace_prefix["conditioner.embedders.0.transformer.text_model"] = "clip_l.transformer.text_model"
+        replace_prefix["conditioner.embedders.1.model."] = "clip_g."
+        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
 
-        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
+        state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32)
+        keys_to_replace["clip_g.text_projection.weight"] = "clip_g.text_projection"
+
         state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
         return state_dict
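SDXL handles its two text encoders the same way: both embedder prefixes collapse into clip_l. / clip_g. in a single filtered pass, and the transformers-format conversion then runs on the already-renamed keys. A small illustration with the sketch helper from above (placeholder values, example key names):

sdxl_sd = {  # tiny stand-in for a real SDXL checkpoint state dict
    "conditioner.embedders.0.transformer.text_model.final_layer_norm.weight": 1,
    "conditioner.embedders.1.model.logit_scale": 2,
    "model.diffusion_model.out.2.weight": 3,  # UNet key, filtered out
}
replace_prefix = {
    "conditioner.embedders.0.transformer.text_model": "clip_l.transformer.text_model",
    "conditioner.embedders.1.model.": "clip_g.",
}
clip_sd = state_dict_prefix_replace_sketch(sdxl_sd, replace_prefix, filter_keys=True)
# clip_sd keys: clip_l.transformer.text_model.final_layer_norm.weight, clip_g.logit_scale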
comfy/supported_models_base.py (view file @ d91f45ef)
@@ -22,6 +22,7 @@ class BASE:
     sampling_settings = {}
     latent_format = latent_formats.LatentFormat
     vae_key_prefix = ["first_stage_model."]
+    text_encoder_key_prefix = ["cond_stage_model."]
 
     supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
     manual_cast_dtype = None
@@ -55,6 +56,7 @@ class BASE:
         return out
 
     def process_clip_state_dict(self, state_dict):
+        state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
         return state_dict
 
     def process_unet_state_dict(self, state_dict):
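The base-class default now strips the configured prefix itself: every key under one of the entries in text_encoder_key_prefix is kept with the prefix removed, and everything else is discarded. Illustrative behavior, again via the sketch helper (placeholder values, example key names):

prefixes = ["cond_stage_model."]  # BASE.text_encoder_key_prefix
sd = {
    "cond_stage_model.transformer.text_model.final_layer_norm.weight": 1,
    "first_stage_model.decoder.conv_in.weight": 2,  # VAE key, discarded
}
clip_sd = state_dict_prefix_replace_sketch(sd, {k: "" for k in prefixes}, filter_keys=True)
# clip_sd == {"transformer.text_model.final_layer_norm.weight": 1}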
@@ -64,7 +66,7 @@ class BASE:
         return state_dict
 
     def process_clip_state_dict_for_saving(self, state_dict):
-        replace_prefix = {"": "cond_stage_model."}
+        replace_prefix = {"": self.text_encoder_key_prefix[0]}
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)
 
     def process_clip_vision_state_dict_for_saving(self, state_dict):
@@ -78,7 +80,7 @@ class BASE:
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)
 
     def process_vae_state_dict_for_saving(self, state_dict):
-        replace_prefix = {"": "first_stage_model."}
+        replace_prefix = {"": self.vae_key_prefix[0]}
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)
 
     def set_inference_dtype(self, dtype, manual_cast_dtype):
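The saving paths are the mirror image: mapping the empty prefix back onto the class prefix prepends it to every key, so a subclass with a non-default text_encoder_key_prefix or vae_key_prefix serializes under the right names. A hypothetical sketch of that round trip, using the sketch helper defined earlier (example key names, not taken from the repo):

class ExampleConfig:
    # Hypothetical subclass carrying the prefixes used when saving.
    text_encoder_key_prefix = ["cond_stage_model."]
    vae_key_prefix = ["first_stage_model."]

    def process_clip_state_dict_for_saving(self, state_dict):
        # "" matches every key, so each one gains the class prefix.
        return state_dict_prefix_replace_sketch(state_dict, {"": self.text_encoder_key_prefix[0]})

saved = ExampleConfig().process_clip_state_dict_for_saving(
    {"transformer.text_model.final_layer_norm.weight": 1})
# saved == {"cond_stage_model.transformer.text_model.final_layer_norm.weight": 1}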