Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
5889b7ca
Commit
5889b7ca
authored
Jun 11, 2024
by
comfyanonymous
Browse files
Support multiple text encoder configurations on SD3.
parent
1c34d338
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
89 additions
and
33 deletions
+89
-33
comfy/sd.py
comfy/sd.py
+1
-1
comfy/sd3_clip.py
comfy/sd3_clip.py
+63
-22
comfy/supported_models.py
comfy/supported_models.py
+25
-10
No files found.
comfy/sd.py
View file @
5889b7ca
...
...
@@ -482,7 +482,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
vae
=
VAE
(
sd
=
vae_sd
)
if
output_clip
:
clip_target
=
model_config
.
clip_target
()
clip_target
=
model_config
.
clip_target
(
state_dict
=
sd
)
if
clip_target
is
not
None
:
clip_sd
=
model_config
.
process_clip_state_dict
(
sd
)
if
len
(
clip_sd
)
>
0
:
...
...
comfy/sd3_clip.py
View file @
5889b7ca
...
...
@@ -5,6 +5,7 @@ import comfy.t5
import
torch
import
os
import
comfy.model_management
import
logging
class
T5XXLModel
(
sd1_clip
.
SDClipModel
):
def
__init__
(
self
,
device
=
"cpu"
,
layer
=
"last"
,
layer_idx
=
None
,
dtype
=
None
):
...
...
@@ -43,20 +44,39 @@ class SD3Tokenizer:
return
self
.
clip_g
.
untokenize
(
token_weight_pair
)
class
SD3ClipModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
device
=
"cpu"
,
dtype
=
None
):
def
__init__
(
self
,
clip_l
=
True
,
clip_g
=
True
,
t5
=
True
,
device
=
"cpu"
,
dtype
=
None
):
super
().
__init__
()
if
clip_l
:
self
.
clip_l
=
sd1_clip
.
SDClipModel
(
layer
=
"hidden"
,
layer_idx
=-
2
,
device
=
device
,
dtype
=
dtype
,
layer_norm_hidden_state
=
False
,
return_projected_pooled
=
False
)
else
:
self
.
clip_l
=
None
if
clip_g
:
self
.
clip_g
=
sdxl_clip
.
SDXLClipG
(
device
=
device
,
dtype
=
dtype
)
else
:
self
.
clip_g
=
None
if
t5
:
self
.
t5xxl
=
T5XXLModel
(
device
=
device
,
dtype
=
dtype
)
else
:
self
.
t5xxl
=
None
logging
.
debug
(
"Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}"
.
format
(
clip_l
,
clip_g
,
t5
))
def
set_clip_options
(
self
,
options
):
if
self
.
clip_l
is
not
None
:
self
.
clip_l
.
set_clip_options
(
options
)
if
self
.
clip_g
is
not
None
:
self
.
clip_g
.
set_clip_options
(
options
)
if
self
.
t5xxl
is
not
None
:
self
.
t5xxl
.
set_clip_options
(
options
)
def
reset_clip_options
(
self
):
self
.
clip_
g
.
reset_clip_options
()
if
self
.
clip_
l
is
not
None
:
self
.
clip_l
.
reset_clip_options
()
if
self
.
clip_g
is
not
None
:
self
.
clip_g
.
reset_clip_options
()
if
self
.
t5xxl
is
not
None
:
self
.
t5xxl
.
reset_clip_options
()
def
encode_token_weights
(
self
,
token_weight_pairs
):
...
...
@@ -64,22 +84,43 @@ class SD3ClipModel(torch.nn.Module):
token_weight_pairs_g
=
token_weight_pairs
[
"g"
]
token_weight_pars_t5
=
token_weight_pairs
[
"t5xxl"
]
lg_out
=
None
pooled
=
None
out
=
None
if
len
(
token_weight_pairs_g
)
>
0
or
len
(
token_weight_pairs_l
)
>
0
:
l_out
,
l_pooled
=
self
.
clip_l
.
encode_token_weights
(
token_weight_pairs_l
)
if
self
.
clip_l
is
not
None
:
lg_out
,
l_pooled
=
self
.
clip_l
.
encode_token_weights
(
token_weight_pairs_l
)
else
:
l_pooled
=
torch
.
zeros
((
1
,
768
),
device
=
comfy
.
model_management
.
intermediate_device
())
if
self
.
clip_g
is
not
None
:
g_out
,
g_pooled
=
self
.
clip_g
.
encode_token_weights
(
token_weight_pairs_g
)
lg_out
=
torch
.
cat
([
l_out
,
g_out
],
dim
=-
1
)
if
lg_out
is
not
None
:
lg_out
=
torch
.
cat
([
lg_out
,
g_out
],
dim
=-
1
)
else
:
lg_out
=
torch
.
nn
.
functional
.
pad
(
g_out
,
(
768
,
0
))
else
:
g_out
=
None
g_pooled
=
torch
.
zeros
((
1
,
1280
),
device
=
comfy
.
model_management
.
intermediate_device
())
if
lg_out
is
not
None
:
lg_out
=
torch
.
nn
.
functional
.
pad
(
lg_out
,
(
0
,
4096
-
lg_out
.
shape
[
-
1
]))
out
=
lg_out
pooled
=
torch
.
cat
((
l_pooled
,
g_pooled
),
dim
=-
1
)
else
:
pooled
=
torch
.
zeros
((
1
,
1280
+
768
),
device
=
comfy
.
model_management
.
intermediate_device
())
if
self
.
t5xxl
is
not
None
:
t5_out
,
t5_pooled
=
self
.
t5xxl
.
encode_token_weights
(
token_weight_pars_t5
)
if
lg_out
is
not
None
:
out
=
torch
.
cat
([
lg_out
,
t5_out
],
dim
=-
2
)
else
:
out
=
t5_out
if
out
is
None
:
out
=
torch
.
zeros
((
1
,
77
,
4096
),
device
=
comfy
.
model_management
.
intermediate_device
())
if
pooled
is
None
:
pooled
=
torch
.
zeros
((
1
,
768
+
1280
),
device
=
comfy
.
model_management
.
intermediate_device
())
return
out
,
pooled
def
load_sd
(
self
,
sd
):
...
...
comfy/supported_models.py
View file @
5889b7ca
...
...
@@ -54,7 +54,7 @@ class SD15(supported_models_base.BASE):
replace_prefix
=
{
"clip_l."
:
"cond_stage_model."
}
return
utils
.
state_dict_prefix_replace
(
state_dict
,
replace_prefix
)
def
clip_target
(
self
):
def
clip_target
(
self
,
state_dict
=
{}
):
return
supported_models_base
.
ClipTarget
(
sd1_clip
.
SD1Tokenizer
,
sd1_clip
.
SD1ClipModel
)
class
SD20
(
supported_models_base
.
BASE
):
...
...
@@ -97,7 +97,7 @@ class SD20(supported_models_base.BASE):
state_dict
=
diffusers_convert
.
convert_text_enc_state_dict_v20
(
state_dict
)
return
state_dict
def
clip_target
(
self
):
def
clip_target
(
self
,
state_dict
=
{}
):
return
supported_models_base
.
ClipTarget
(
sd2_clip
.
SD2Tokenizer
,
sd2_clip
.
SD2ClipModel
)
class
SD21UnclipL
(
SD20
):
...
...
@@ -159,7 +159,7 @@ class SDXLRefiner(supported_models_base.BASE):
state_dict_g
=
utils
.
state_dict_prefix_replace
(
state_dict_g
,
replace_prefix
)
return
state_dict_g
def
clip_target
(
self
):
def
clip_target
(
self
,
state_dict
=
{}
):
return
supported_models_base
.
ClipTarget
(
sdxl_clip
.
SDXLTokenizer
,
sdxl_clip
.
SDXLRefinerClipModel
)
class
SDXL
(
supported_models_base
.
BASE
):
...
...
@@ -228,7 +228,7 @@ class SDXL(supported_models_base.BASE):
state_dict_g
=
utils
.
state_dict_prefix_replace
(
state_dict_g
,
replace_prefix
)
return
state_dict_g
def
clip_target
(
self
):
def
clip_target
(
self
,
state_dict
=
{}
):
return
supported_models_base
.
ClipTarget
(
sdxl_clip
.
SDXLTokenizer
,
sdxl_clip
.
SDXLClipModel
)
class
SSD1B
(
SDXL
):
...
...
@@ -299,7 +299,7 @@ class SVD_img2vid(supported_models_base.BASE):
out
=
model_base
.
SVD_img2vid
(
self
,
device
=
device
)
return
out
def
clip_target
(
self
):
def
clip_target
(
self
,
state_dict
=
{}
):
return
None
class
SV3D_u
(
SVD_img2vid
):
...
...
@@ -365,7 +365,7 @@ class Stable_Zero123(supported_models_base.BASE):
out
=
model_base
.
Stable_Zero123
(
self
,
device
=
device
,
cc_projection_weight
=
state_dict
[
"cc_projection.weight"
],
cc_projection_bias
=
state_dict
[
"cc_projection.bias"
])
return
out
def
clip_target
(
self
):
def
clip_target
(
self
,
state_dict
=
{}
):
return
None
class
SD_X4Upscaler
(
SD20
):
...
...
@@ -439,7 +439,7 @@ class Stable_Cascade_C(supported_models_base.BASE):
out
=
model_base
.
StableCascade_C
(
self
,
device
=
device
)
return
out
def
clip_target
(
self
):
def
clip_target
(
self
,
state_dict
=
{}
):
return
supported_models_base
.
ClipTarget
(
sdxl_clip
.
StableCascadeTokenizer
,
sdxl_clip
.
StableCascadeClipModel
)
class
Stable_Cascade_B
(
Stable_Cascade_C
):
...
...
@@ -501,14 +501,29 @@ class SD3(supported_models_base.BASE):
unet_extra_config
=
{}
latent_format
=
latent_formats
.
SD3
text_encoder_key_prefix
=
[
"text_encoders."
]
#TODO?
text_encoder_key_prefix
=
[
"text_encoders."
]
def
get_model
(
self
,
state_dict
,
prefix
=
""
,
device
=
None
):
out
=
model_base
.
SD3
(
self
,
device
=
device
)
return
out
def
clip_target
(
self
):
return
supported_models_base
.
ClipTarget
(
sd3_clip
.
SD3Tokenizer
,
sd3_clip
.
SD3ClipModel
)
#TODO?
def
clip_target
(
self
,
state_dict
=
{}):
clip_l
=
False
clip_g
=
False
t5
=
False
pref
=
self
.
text_encoder_key_prefix
[
0
]
if
"{}clip_l.transformer.text_model.final_layer_norm.weight"
.
format
(
pref
)
in
state_dict
:
clip_l
=
True
if
"{}clip_g.transformer.text_model.final_layer_norm.weight"
.
format
(
pref
)
in
state_dict
:
clip_g
=
True
if
"{}t5xxl.transformer.encoder.final_layer_norm.weight"
.
format
(
pref
)
in
state_dict
:
t5
=
True
class
SD3ClipModel
(
sd3_clip
.
SD3ClipModel
):
def
__init__
(
self
,
device
=
"cpu"
,
dtype
=
None
):
super
().
__init__
(
clip_l
=
clip_l
,
clip_g
=
clip_g
,
t5
=
t5
,
device
=
device
,
dtype
=
dtype
)
return
supported_models_base
.
ClipTarget
(
sd3_clip
.
SD3Tokenizer
,
SD3ClipModel
)
models
=
[
Stable_Zero123
,
SD15_instructpix2pix
,
SD15
,
SD20
,
SD21UnclipL
,
SD21UnclipH
,
SDXL_instructpix2pix
,
SDXLRefiner
,
SDXL
,
SSD1B
,
KOALA_700M
,
KOALA_1B
,
Segmind_Vega
,
SD_X4Upscaler
,
Stable_Cascade_C
,
Stable_Cascade_B
,
SV3D_u
,
SV3D_p
,
SD3
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment