Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
97d03ae0
"vscode:/vscode.git/clone" did not exist on "3b5badb770ad2b91cac4e046e34adf163e8cbf21"
Commit
97d03ae0
authored
Feb 16, 2024
by
comfyanonymous
Browse files
StableCascade CLIP model support.
parent
667c9281
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
43 additions
and
8 deletions
+43
-8
comfy/sd.py
comfy/sd.py
+11
-3
comfy/sd1_clip.py
comfy/sd1_clip.py
+2
-2
comfy/sdxl_clip.py
comfy/sdxl_clip.py
+22
-0
comfy/supported_models.py
comfy/supported_models.py
+1
-1
nodes.py
nodes.py
+7
-2
No files found.
comfy/sd.py
View file @
97d03ae0
import
torch
import
torch
from
enum
import
Enum
from
comfy
import
model_management
from
comfy
import
model_management
from
.ldm.models.autoencoder
import
AutoencoderKL
,
AutoencodingEngine
from
.ldm.models.autoencoder
import
AutoencoderKL
,
AutoencodingEngine
...
@@ -309,8 +310,11 @@ def load_style_model(ckpt_path):
...
@@ -309,8 +310,11 @@ def load_style_model(ckpt_path):
model
.
load_state_dict
(
model_data
)
model
.
load_state_dict
(
model_data
)
return
StyleModel
(
model
)
return
StyleModel
(
model
)
class CLIPType(Enum):
    """Selects which text-encoder family ``load_clip`` should build.

    Passed as the ``clip_type`` keyword of ``load_clip``; the default there
    is ``STABLE_DIFFUSION``.
    """
    STABLE_DIFFUSION = 1
    STABLE_CASCADE = 2
def
load_clip
(
ckpt_paths
,
embedding_directory
=
None
):
def
load_clip
(
ckpt_paths
,
embedding_directory
=
None
,
clip_type
=
CLIPType
.
STABLE_DIFFUSION
):
clip_data
=
[]
clip_data
=
[]
for
p
in
ckpt_paths
:
for
p
in
ckpt_paths
:
clip_data
.
append
(
comfy
.
utils
.
load_torch_file
(
p
,
safe_load
=
True
))
clip_data
.
append
(
comfy
.
utils
.
load_torch_file
(
p
,
safe_load
=
True
))
...
@@ -326,8 +330,12 @@ def load_clip(ckpt_paths, embedding_directory=None):
...
@@ -326,8 +330,12 @@ def load_clip(ckpt_paths, embedding_directory=None):
clip_target
.
params
=
{}
clip_target
.
params
=
{}
if
len
(
clip_data
)
==
1
:
if
len
(
clip_data
)
==
1
:
if
"text_model.encoder.layers.30.mlp.fc1.weight"
in
clip_data
[
0
]:
if
"text_model.encoder.layers.30.mlp.fc1.weight"
in
clip_data
[
0
]:
clip_target
.
clip
=
sdxl_clip
.
SDXLRefinerClipModel
if
clip_type
==
CLIPType
.
STABLE_CASCADE
:
clip_target
.
tokenizer
=
sdxl_clip
.
SDXLTokenizer
clip_target
.
clip
=
sdxl_clip
.
StableCascadeClipModel
clip_target
.
tokenizer
=
sdxl_clip
.
StableCascadeTokenizer
else
:
clip_target
.
clip
=
sdxl_clip
.
SDXLRefinerClipModel
clip_target
.
tokenizer
=
sdxl_clip
.
SDXLTokenizer
elif
"text_model.encoder.layers.22.mlp.fc1.weight"
in
clip_data
[
0
]:
elif
"text_model.encoder.layers.22.mlp.fc1.weight"
in
clip_data
[
0
]:
clip_target
.
clip
=
sd2_clip
.
SD2ClipModel
clip_target
.
clip
=
sd2_clip
.
SD2ClipModel
clip_target
.
tokenizer
=
sd2_clip
.
SD2Tokenizer
clip_target
.
tokenizer
=
sd2_clip
.
SD2Tokenizer
...
...
comfy/sd1_clip.py
View file @
97d03ae0
...
@@ -67,7 +67,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
...
@@ -67,7 +67,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
]
]
def
__init__
(
self
,
version
=
"openai/clip-vit-large-patch14"
,
device
=
"cpu"
,
max_length
=
77
,
def
__init__
(
self
,
version
=
"openai/clip-vit-large-patch14"
,
device
=
"cpu"
,
max_length
=
77
,
freeze
=
True
,
layer
=
"last"
,
layer_idx
=
None
,
textmodel_json_config
=
None
,
dtype
=
None
,
model_class
=
comfy
.
clip_model
.
CLIPTextModel
,
freeze
=
True
,
layer
=
"last"
,
layer_idx
=
None
,
textmodel_json_config
=
None
,
dtype
=
None
,
model_class
=
comfy
.
clip_model
.
CLIPTextModel
,
special_tokens
=
{
"start"
:
49406
,
"end"
:
49407
,
"pad"
:
49407
},
layer_norm_hidden_state
=
True
):
# clip-vit-base-patch32
special_tokens
=
{
"start"
:
49406
,
"end"
:
49407
,
"pad"
:
49407
},
layer_norm_hidden_state
=
True
,
enable_attention_masks
=
False
):
# clip-vit-base-patch32
super
().
__init__
()
super
().
__init__
()
assert
layer
in
self
.
LAYERS
assert
layer
in
self
.
LAYERS
...
@@ -88,7 +88,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
...
@@ -88,7 +88,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self
.
special_tokens
=
special_tokens
self
.
special_tokens
=
special_tokens
self
.
text_projection
=
torch
.
nn
.
Parameter
(
torch
.
eye
(
self
.
transformer
.
get_input_embeddings
().
weight
.
shape
[
1
]))
self
.
text_projection
=
torch
.
nn
.
Parameter
(
torch
.
eye
(
self
.
transformer
.
get_input_embeddings
().
weight
.
shape
[
1
]))
self
.
logit_scale
=
torch
.
nn
.
Parameter
(
torch
.
tensor
(
4.6055
))
self
.
logit_scale
=
torch
.
nn
.
Parameter
(
torch
.
tensor
(
4.6055
))
self
.
enable_attention_masks
=
False
self
.
enable_attention_masks
=
enable_attention_masks
self
.
layer_norm_hidden_state
=
layer_norm_hidden_state
self
.
layer_norm_hidden_state
=
layer_norm_hidden_state
if
layer
==
"hidden"
:
if
layer
==
"hidden"
:
...
...
comfy/sdxl_clip.py
View file @
97d03ae0
...
@@ -64,3 +64,25 @@ class SDXLClipModel(torch.nn.Module):
...
@@ -64,3 +64,25 @@ class SDXLClipModel(torch.nn.Module):
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
    """Single-encoder CLIP wrapper for the SDXL refiner.

    Exposes the CLIP-G encoder (``SDXLClipG``) under the clip name ``"g"``.
    """

    def __init__(self, device="cpu", dtype=None):
        super().__init__(
            device=device,
            dtype=dtype,
            clip_name="g",
            clip_model=SDXLClipG,
        )
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
    """Tokenizer for the Stable Cascade CLIP-G encoder.

    Configured for 1280-dimensional embeddings under the embedding key
    ``'clip_g'``, padding sequences with the end token.
    """

    def __init__(self, tokenizer_path=None, embedding_directory=None):
        super().__init__(
            tokenizer_path,
            pad_with_end=True,
            embedding_directory=embedding_directory,
            embedding_size=1280,
            embedding_key='clip_g',
        )
class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
    """Single-tokenizer wrapper for Stable Cascade.

    Exposes ``StableCascadeClipGTokenizer`` under the clip name ``"g"``.
    """

    def __init__(self, embedding_directory=None):
        super().__init__(
            embedding_directory=embedding_directory,
            clip_name="g",
            tokenizer=StableCascadeClipGTokenizer,
        )
class StableCascadeClipG(sd1_clip.SDClipModel):
    """CLIP-G text encoder configured for Stable Cascade.

    Loads the big-G text-model config (``clip_config_bigg.json``) and, unlike
    the plain SDXL CLIP-G setup, disables layer-norm on the returned hidden
    state (``layer_norm_hidden_state=False``) and enables attention masks
    (``enable_attention_masks=True``).

    Note: the original class also carried a ``load_sd`` override that only
    called ``super().load_sd(sd)``; it was a pure pass-through and has been
    removed — the inherited implementation is used directly, with identical
    behavior.
    """

    def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):
        # NOTE(review): max_length is accepted for signature parity with the
        # other CLIP wrappers but is NOT forwarded to the base class — confirm
        # the base default matches if this ever needs to vary.
        textmodel_json_config = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            "clip_config_bigg.json",
        )
        super().__init__(
            device=device,
            freeze=freeze,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config=textmodel_json_config,
            dtype=dtype,
            # Standard CLIP BPE special token ids; pad reuses the end token.
            special_tokens={"start": 49406, "end": 49407, "pad": 49407},
            layer_norm_hidden_state=False,
            enable_attention_masks=True,
        )
class StableCascadeClipModel(sd1_clip.SD1ClipModel):
    """Single-encoder CLIP wrapper for Stable Cascade.

    Exposes ``StableCascadeClipG`` under the clip name ``"g"``.
    """

    def __init__(self, device="cpu", dtype=None):
        super().__init__(
            device=device,
            dtype=dtype,
            clip_name="g",
            clip_model=StableCascadeClipG,
        )
comfy/supported_models.py
View file @
97d03ae0
...
@@ -336,7 +336,7 @@ class Stable_Cascade_C(supported_models_base.BASE):
...
@@ -336,7 +336,7 @@ class Stable_Cascade_C(supported_models_base.BASE):
return
out
return
out
def clip_target(self):
    """Return the tokenizer/encoder pair this model family uses.

    Stable Cascade conditions on CLIP-G only, so the target is the
    G-only tokenizer and encoder from ``sdxl_clip``.
    """
    tokenizer = sdxl_clip.StableCascadeTokenizer
    encoder = sdxl_clip.StableCascadeClipModel
    return supported_models_base.ClipTarget(tokenizer, encoder)
class
Stable_Cascade_B
(
Stable_Cascade_C
):
class
Stable_Cascade_B
(
Stable_Cascade_C
):
unet_config
=
{
unet_config
=
{
...
...
nodes.py
View file @
97d03ae0
...
@@ -854,15 +854,20 @@ class CLIPLoader:
...
@@ -854,15 +854,20 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
    """Describe the node's inputs: a clip checkpoint picker plus the
    CLIP-family selector added for Stable Cascade support."""
    clip_names = folder_paths.get_filename_list("clip")
    required = {
        "clip_name": (clip_names,),
        "type": (["stable_diffusion", "stable_cascade"],),
    }
    return {"required": required}
RETURN_TYPES
=
(
"CLIP"
,)
RETURN_TYPES
=
(
"CLIP"
,)
FUNCTION
=
"load_clip"
FUNCTION
=
"load_clip"
CATEGORY
=
"advanced/loaders"
CATEGORY
=
"advanced/loaders"
def load_clip(self, clip_name, type="stable_diffusion"):
    """Load a single CLIP checkpoint.

    ``type`` selects the target encoder family ("stable_diffusion" or
    "stable_cascade"); any other value falls back to stable diffusion,
    matching the original behavior. (The parameter name shadows the
    builtin ``type`` but is part of the node's public keyword interface.)
    Returns a one-tuple containing the loaded CLIP object.
    """
    if type == "stable_cascade":
        clip_type = comfy.sd.CLIPType.STABLE_CASCADE
    else:
        clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
    clip_path = folder_paths.get_full_path("clip", clip_name)
    embedding_dirs = folder_paths.get_folder_paths("embeddings")
    clip = comfy.sd.load_clip(
        ckpt_paths=[clip_path],
        embedding_directory=embedding_dirs,
        clip_type=clip_type,
    )
    return (clip,)
class
DualCLIPLoader
:
class
DualCLIPLoader
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment