Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
20f579d9
Commit
20f579d9
authored
Jun 25, 2023
by
comfyanonymous
Browse files
Add DualClipLoader to load clip models for SDXL.
Update LoadClip to load clip models for SDXL refiner.
parent
b7933960
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
67 additions
and
11 deletions
+67
-11
comfy/sd.py
comfy/sd.py
+32
-9
comfy/sd1_clip.py
comfy/sd1_clip.py
+3
-0
comfy/sdxl_clip.py
comfy/sdxl_clip.py
+13
-0
nodes.py
nodes.py
+19
-2
No files found.
comfy/sd.py
View file @
20f579d9
...
@@ -19,6 +19,7 @@ from . import model_detection
...
@@ -19,6 +19,7 @@ from . import model_detection
from
.
import
sd1_clip
from
.
import
sd1_clip
from
.
import
sd2_clip
from
.
import
sd2_clip
from
.
import
sdxl_clip
def
load_model_weights
(
model
,
sd
):
def
load_model_weights
(
model
,
sd
):
m
,
u
=
model
.
load_state_dict
(
sd
,
strict
=
False
)
m
,
u
=
model
.
load_state_dict
(
sd
,
strict
=
False
)
...
@@ -524,7 +525,7 @@ class CLIP:
...
@@ -524,7 +525,7 @@ class CLIP:
return
n
return
n
def
load_from_state_dict
(
self
,
sd
):
def
load_from_state_dict
(
self
,
sd
):
self
.
cond_stage_model
.
transformer
.
load_state_dict
(
sd
,
strict
=
False
)
self
.
cond_stage_model
.
load_sd
(
sd
)
def
add_patches
(
self
,
patches
,
strength
=
1.0
):
def
add_patches
(
self
,
patches
,
strength
=
1.0
):
return
self
.
patcher
.
add_patches
(
patches
,
strength
)
return
self
.
patcher
.
add_patches
(
patches
,
strength
)
...
@@ -555,6 +556,8 @@ class CLIP:
...
@@ -555,6 +556,8 @@ class CLIP:
tokens
=
self
.
tokenize
(
text
)
tokens
=
self
.
tokenize
(
text
)
return
self
.
encode_from_tokens
(
tokens
)
return
self
.
encode_from_tokens
(
tokens
)
def
load_sd
(
self
,
sd
):
return
self
.
cond_stage_model
.
load_sd
(
sd
)
class
VAE
:
class
VAE
:
def
__init__
(
self
,
ckpt_path
=
None
,
device
=
None
,
config
=
None
):
def
__init__
(
self
,
ckpt_path
=
None
,
device
=
None
,
config
=
None
):
...
@@ -959,22 +962,42 @@ def load_style_model(ckpt_path):
...
@@ -959,22 +962,42 @@ def load_style_model(ckpt_path):
return
StyleModel
(
model
)
return
StyleModel
(
model
)
def
load_clip
(
ckpt_path
,
embedding_directory
=
None
):
def
load_clip
(
ckpt_paths
,
embedding_directory
=
None
):
clip_data
=
utils
.
load_torch_file
(
ckpt_path
,
safe_load
=
True
)
clip_data
=
[]
for
p
in
ckpt_paths
:
clip_data
.
append
(
utils
.
load_torch_file
(
p
,
safe_load
=
True
))
class
EmptyClass
:
class
EmptyClass
:
pass
pass
for
i
in
range
(
len
(
clip_data
)):
if
"transformer.resblocks.0.ln_1.weight"
in
clip_data
[
i
]:
clip_data
[
i
]
=
utils
.
transformers_convert
(
clip_data
[
i
],
""
,
"text_model."
,
32
)
clip_target
=
EmptyClass
()
clip_target
=
EmptyClass
()
clip_target
.
params
=
{}
clip_target
.
params
=
{}
if
"text_model.encoder.layers.22.mlp.fc1.weight"
in
clip_data
:
if
len
(
clip_data
)
==
1
:
if
"text_model.encoder.layers.30.mlp.fc1.weight"
in
clip_data
[
0
]:
clip_target
.
clip
=
sdxl_clip
.
SDXLRefinerClipModel
clip_target
.
tokenizer
=
sdxl_clip
.
SDXLTokenizer
elif
"text_model.encoder.layers.22.mlp.fc1.weight"
in
clip_data
[
0
]:
clip_target
.
clip
=
sd2_clip
.
SD2ClipModel
clip_target
.
clip
=
sd2_clip
.
SD2ClipModel
clip_target
.
tokenizer
=
sd2_clip
.
SD2Tokenizer
clip_target
.
tokenizer
=
sd2_clip
.
SD2Tokenizer
else
:
else
:
clip_target
.
clip
=
sd1_clip
.
SD1ClipModel
clip_target
.
clip
=
sd1_clip
.
SD1ClipModel
clip_target
.
tokenizer
=
sd1_clip
.
SD1Tokenizer
clip_target
.
tokenizer
=
sd1_clip
.
SD1Tokenizer
else
:
clip_target
.
clip
=
sdxl_clip
.
SDXLClipModel
clip_target
.
tokenizer
=
sdxl_clip
.
SDXLTokenizer
clip
=
CLIP
(
clip_target
,
embedding_directory
=
embedding_directory
)
clip
=
CLIP
(
clip_target
,
embedding_directory
=
embedding_directory
)
clip
.
load_from_state_dict
(
clip_data
)
for
c
in
clip_data
:
m
,
u
=
clip
.
load_sd
(
c
)
if
len
(
m
)
>
0
:
print
(
"clip missing:"
,
m
)
if
len
(
u
)
>
0
:
print
(
"clip unexpected:"
,
u
)
return
clip
return
clip
def
load_gligen
(
ckpt_path
):
def
load_gligen
(
ckpt_path
):
...
...
comfy/sd1_clip.py
View file @
20f579d9
...
@@ -128,6 +128,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
...
@@ -128,6 +128,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def
encode
(
self
,
tokens
):
def
encode
(
self
,
tokens
):
return
self
(
tokens
)
return
self
(
tokens
)
def
load_sd
(
self
,
sd
):
return
self
.
transformer
.
load_state_dict
(
sd
,
strict
=
False
)
def
parse_parentheses
(
string
):
def
parse_parentheses
(
string
):
result
=
[]
result
=
[]
current_item
=
""
current_item
=
""
...
...
comfy/sdxl_clip.py
View file @
20f579d9
...
@@ -31,6 +31,11 @@ class SDXLClipG(sd1_clip.SD1ClipModel):
...
@@ -31,6 +31,11 @@ class SDXLClipG(sd1_clip.SD1ClipModel):
self
.
layer
=
"hidden"
self
.
layer
=
"hidden"
self
.
layer_idx
=
layer_idx
self
.
layer_idx
=
layer_idx
def
load_sd
(
self
,
sd
):
if
"text_projection"
in
sd
:
self
.
text_projection
[:]
=
sd
.
pop
(
"text_projection"
)
return
super
().
load_sd
(
sd
)
class
SDXLClipGTokenizer
(
sd1_clip
.
SD1Tokenizer
):
class
SDXLClipGTokenizer
(
sd1_clip
.
SD1Tokenizer
):
def
__init__
(
self
,
tokenizer_path
=
None
,
embedding_directory
=
None
):
def
__init__
(
self
,
tokenizer_path
=
None
,
embedding_directory
=
None
):
super
().
__init__
(
tokenizer_path
,
pad_with_end
=
False
,
embedding_directory
=
embedding_directory
,
embedding_size
=
1280
)
super
().
__init__
(
tokenizer_path
,
pad_with_end
=
False
,
embedding_directory
=
embedding_directory
,
embedding_size
=
1280
)
...
@@ -68,6 +73,12 @@ class SDXLClipModel(torch.nn.Module):
...
@@ -68,6 +73,12 @@ class SDXLClipModel(torch.nn.Module):
l_out
,
l_pooled
=
self
.
clip_l
.
encode_token_weights
(
token_weight_pairs_l
)
l_out
,
l_pooled
=
self
.
clip_l
.
encode_token_weights
(
token_weight_pairs_l
)
return
torch
.
cat
([
l_out
,
g_out
],
dim
=-
1
),
g_pooled
return
torch
.
cat
([
l_out
,
g_out
],
dim
=-
1
),
g_pooled
def
load_sd
(
self
,
sd
):
if
"text_model.encoder.layers.30.mlp.fc1.weight"
in
sd
:
return
self
.
clip_g
.
load_sd
(
sd
)
else
:
return
self
.
clip_l
.
load_sd
(
sd
)
class
SDXLRefinerClipModel
(
torch
.
nn
.
Module
):
class
SDXLRefinerClipModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
device
=
"cpu"
):
def
__init__
(
self
,
device
=
"cpu"
):
super
().
__init__
()
super
().
__init__
()
...
@@ -81,3 +92,5 @@ class SDXLRefinerClipModel(torch.nn.Module):
...
@@ -81,3 +92,5 @@ class SDXLRefinerClipModel(torch.nn.Module):
g_out
,
g_pooled
=
self
.
clip_g
.
encode_token_weights
(
token_weight_pairs_g
)
g_out
,
g_pooled
=
self
.
clip_g
.
encode_token_weights
(
token_weight_pairs_g
)
return
g_out
,
g_pooled
return
g_out
,
g_pooled
def
load_sd
(
self
,
sd
):
return
self
.
clip_g
.
load_sd
(
sd
)
nodes.py
View file @
20f579d9
...
@@ -520,11 +520,27 @@ class CLIPLoader:
...
@@ -520,11 +520,27 @@ class CLIPLoader:
RETURN_TYPES
=
(
"CLIP"
,)
RETURN_TYPES
=
(
"CLIP"
,)
FUNCTION
=
"load_clip"
FUNCTION
=
"load_clip"
CATEGORY
=
"loaders"
CATEGORY
=
"
advanced/
loaders"
def
load_clip
(
self
,
clip_name
):
def
load_clip
(
self
,
clip_name
):
clip_path
=
folder_paths
.
get_full_path
(
"clip"
,
clip_name
)
clip_path
=
folder_paths
.
get_full_path
(
"clip"
,
clip_name
)
clip
=
comfy
.
sd
.
load_clip
(
ckpt_path
=
clip_path
,
embedding_directory
=
folder_paths
.
get_folder_paths
(
"embeddings"
))
clip
=
comfy
.
sd
.
load_clip
(
ckpt_paths
=
[
clip_path
],
embedding_directory
=
folder_paths
.
get_folder_paths
(
"embeddings"
))
return
(
clip
,)
class
DualCLIPLoader
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"clip_name1"
:
(
folder_paths
.
get_filename_list
(
"clip"
),
),
"clip_name2"
:
(
folder_paths
.
get_filename_list
(
"clip"
),
),
}}
RETURN_TYPES
=
(
"CLIP"
,)
FUNCTION
=
"load_clip"
CATEGORY
=
"advanced/loaders"
def
load_clip
(
self
,
clip_name1
,
clip_name2
):
clip_path1
=
folder_paths
.
get_full_path
(
"clip"
,
clip_name1
)
clip_path2
=
folder_paths
.
get_full_path
(
"clip"
,
clip_name2
)
clip
=
comfy
.
sd
.
load_clip
(
ckpt_paths
=
[
clip_path1
,
clip_path2
],
embedding_directory
=
folder_paths
.
get_folder_paths
(
"embeddings"
))
return
(
clip
,)
return
(
clip
,)
class
CLIPVisionLoader
:
class
CLIPVisionLoader
:
...
@@ -1315,6 +1331,7 @@ NODE_CLASS_MAPPINGS = {
...
@@ -1315,6 +1331,7 @@ NODE_CLASS_MAPPINGS = {
"LatentCrop"
:
LatentCrop
,
"LatentCrop"
:
LatentCrop
,
"LoraLoader"
:
LoraLoader
,
"LoraLoader"
:
LoraLoader
,
"CLIPLoader"
:
CLIPLoader
,
"CLIPLoader"
:
CLIPLoader
,
"DualCLIPLoader"
:
DualCLIPLoader
,
"CLIPVisionEncode"
:
CLIPVisionEncode
,
"CLIPVisionEncode"
:
CLIPVisionEncode
,
"StyleModelApply"
:
StyleModelApply
,
"StyleModelApply"
:
StyleModelApply
,
"unCLIPConditioning"
:
unCLIPConditioning
,
"unCLIPConditioning"
:
unCLIPConditioning
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment