chenpangpang / ComfyUI · Commit 50b1180d
Fix CLIPSetLastLayer not reverting when removed.
Authored Jul 15, 2023 by comfyanonymous · parent 6fb084f3

Showing 4 changed files with 28 additions and 36 deletions:

  comfy/sd.py         +2   -0
  comfy/sd1_clip.py   +9   -2
  comfy/sd2_clip.py   +5   -12
  comfy/sdxl_clip.py  +12  -22
comfy/sd.py

@@ -493,6 +493,8 @@ class CLIP:
     def encode_from_tokens(self, tokens, return_pooled=False):
         if self.layer_idx is not None:
             self.cond_stage_model.clip_layer(self.layer_idx)
+        else:
+            self.cond_stage_model.reset_clip_layer()
 
         model_management.load_model_gpu(self.patcher)
         cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
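This is the behavioral core of the commit: when a CLIPSetLastLayer node is removed, the wrapper's layer_idx is back to None, but before this change the underlying text model silently kept whatever clip_layer() had last been applied. The added else branch calls reset_clip_layer() so encoding reverts to the model's defaults. Below is a minimal, self-contained sketch of that revert behavior; the stub classes and the encode() helper are illustrative stand-ins, not the real ComfyUI objects.

class TextEncoderStub:
    def __init__(self, default=("last", None)):
        self.layer_default = default            # captured at construction time
        self.layer, self.layer_idx = default

    def clip_layer(self, layer_idx):
        # Select an intermediate hidden layer for the CLIP output.
        self.layer = "hidden"
        self.layer_idx = layer_idx

    def reset_clip_layer(self):
        # Revert to whatever the model was constructed with.
        self.layer, self.layer_idx = self.layer_default


def encode(model, layer_idx):
    # Mirrors the patched encode path: apply the override if one is set,
    # otherwise fall back to the model's defaults before encoding.
    if layer_idx is not None:
        model.clip_layer(layer_idx)
    else:
        model.reset_clip_layer()
    return model.layer, model.layer_idx


enc = TextEncoderStub()
assert encode(enc, -2) == ("hidden", -2)    # CLIPSetLastLayer present
assert encode(enc, None) == ("last", None)  # node removed: state reverts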
comfy/sd1_clip.py

@@ -46,12 +46,14 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                  freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None):  # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
+        self.num_layers = 12
         if textmodel_path is not None:
             self.transformer = CLIPTextModel.from_pretrained(textmodel_path)
         else:
             if textmodel_json_config is None:
                 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
             config = CLIPTextConfig.from_json_file(textmodel_json_config)
+            self.num_layers = config.num_hidden_layers
             with comfy.ops.use_comfy_ops():
                 with modeling_utils.no_init_weights():
                     self.transformer = CLIPTextModel(config)

@@ -66,8 +68,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.layer_norm_hidden_state = True
         if layer == "hidden":
             assert layer_idx is not None
-            assert abs(layer_idx) <= 12
+            assert abs(layer_idx) <= self.num_layers
             self.clip_layer(layer_idx)
+        self.layer_default = (self.layer, self.layer_idx)
 
     def freeze(self):
         self.transformer = self.transformer.eval()

@@ -76,12 +79,16 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             param.requires_grad = False
 
     def clip_layer(self, layer_idx):
-        if abs(layer_idx) >= 12:
+        if abs(layer_idx) >= self.num_layers:
             self.layer = "last"
         else:
             self.layer = "hidden"
             self.layer_idx = layer_idx
 
+    def reset_clip_layer(self):
+        self.layer = self.layer_default[0]
+        self.layer_idx = self.layer_default[1]
+
     def set_up_textual_embeddings(self, tokens, current_embeds):
         out_tokens = []
         next_new_token = token_dict_size = current_embeds.weight.shape[0]
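Two things change in the base model: the hard-coded 12 (the depth of SD1's CLIP ViT-L text encoder) becomes self.num_layers, read from the loaded config so subclasses with deeper text encoders get a correct bounds check, and the (layer, layer_idx) pair in effect at the end of __init__ is remembered as layer_default so reset_clip_layer() can restore it later. A small sketch of the clamp in clip_layer(); the encoder depths used in the example (12 for SD1, 32 for CLIP-G) are taken from the asserts elsewhere in this diff, not from this hunk.

def select_layer(layer_idx, num_layers):
    # Mirrors the clamp in SD1ClipModel.clip_layer(): indices at or beyond the
    # encoder depth fall back to the final layer output.
    return "last" if abs(layer_idx) >= num_layers else "hidden"

print(select_layer(-12, 12))   # 'last'   -- out of range for a 12-layer ViT-L
print(select_layer(-12, 32))   # 'hidden' -- valid for a 32-layer encoder such as CLIP-G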
comfy/sd2_clip.py

@@ -4,20 +4,13 @@ import os
 
 class SD2ClipModel(sd1_clip.SD1ClipModel):
     def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None):
+        if layer == "penultimate":
+            layer = "hidden"
+            layer_idx = 23
+
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
-        super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
         self.empty_tokens = [[49406] + [49407] + [0] * 75]
-        if layer == "last":
-            pass
-        elif layer == "penultimate":
-            layer_idx = -1
-            self.clip_layer(layer_idx)
-        elif self.layer == "hidden":
-            assert layer_idx is not None
-            assert abs(layer_idx) < 24
-            self.clip_layer(layer_idx)
-        else:
-            raise NotImplementedError()
 
     def clip_layer(self, layer_idx):
         if layer_idx < 0:
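The penultimate case is now resolved up front: layer="penultimate" is rewritten to ("hidden", 23) before calling the base constructor, instead of going through the removed if/elif dispatch. A short worked example of why index 23 is the penultimate output for a 24-layer encoder, using a plain list to stand in for a Hugging Face hidden_states tuple (embedding output plus one entry per layer); the depth of 24 is implied by the old abs(layer_idx) < 24 assert rather than stated in this hunk.

num_hidden_layers = 24                                            # ViT-H text encoder depth
hidden_states = [f"h{i}" for i in range(num_hidden_layers + 1)]   # 25 entries: embeddings + 24 layers

# Index 23 is the second-to-last entry, i.e. the penultimate layer's output.
assert hidden_states[23] is hidden_states[-2]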
comfy/sdxl_clip.py

@@ -4,33 +4,16 @@ import os
 
 class SDXLClipG(sd1_clip.SD1ClipModel):
     def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None):
+        if layer == "penultimate":
+            layer = "hidden"
+            layer_idx = -2
+
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
-        super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path)
         self.empty_tokens = [[49406] + [49407] + [0] * 75]
         self.text_projection = torch.nn.Parameter(torch.empty(1280, 1280))
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.layer_norm_hidden_state = False
-        if layer == "last":
-            pass
-        elif layer == "penultimate":
-            layer_idx = -1
-            self.clip_layer(layer_idx)
-        elif self.layer == "hidden":
-            assert layer_idx is not None
-            assert abs(layer_idx) < 32
-            self.clip_layer(layer_idx)
-        else:
-            raise NotImplementedError()
-
-    def clip_layer(self, layer_idx):
-        if layer_idx < 0:
-            layer_idx -= 1 #The real last layer of SD2.x clip is the penultimate one. The last one might contain garbage.
-        if abs(layer_idx) >= 32:
-            self.layer = "hidden"
-            self.layer_idx = -2
-        else:
-            self.layer = "hidden"
-            self.layer_idx = layer_idx
 
     def load_sd(self, sd):
         if "text_projection" in sd:

@@ -69,6 +52,10 @@ class SDXLClipModel(torch.nn.Module):
         self.clip_l.clip_layer(layer_idx)
         self.clip_g.clip_layer(layer_idx)
 
+    def reset_clip_layer(self):
+        self.clip_g.reset_clip_layer()
+        self.clip_l.reset_clip_layer()
+
     def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs_g = token_weight_pairs["g"]
         token_weight_pairs_l = token_weight_pairs["l"]

@@ -90,6 +77,9 @@ class SDXLRefinerClipModel(torch.nn.Module):
     def clip_layer(self, layer_idx):
         self.clip_g.clip_layer(layer_idx)
 
+    def reset_clip_layer(self):
+        self.clip_g.reset_clip_layer()
+
     def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs_g = token_weight_pairs["g"]
         g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
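SDXLClipModel wraps two text encoders (clip_l and clip_g) and the refiner wraps only clip_g, so the new reset_clip_layer() methods fan the call out the same way clip_layer() already does, keeping both encoders in step. A rough sketch of that forwarding pattern with stand-in classes (not the real SDXL wrappers):

class EncoderStub:
    # Stand-in for a single text encoder exposing clip_layer/reset_clip_layer.
    def __init__(self):
        self.state = "default"

    def clip_layer(self, layer_idx):
        self.state = ("hidden", layer_idx)

    def reset_clip_layer(self):
        self.state = "default"


class DualClipStub:
    # Stand-in for the dual-encoder wrapper: every call is forwarded to both children.
    def __init__(self):
        self.clip_l, self.clip_g = EncoderStub(), EncoderStub()

    def clip_layer(self, layer_idx):
        self.clip_l.clip_layer(layer_idx)
        self.clip_g.clip_layer(layer_idx)

    def reset_clip_layer(self):
        self.clip_g.reset_clip_layer()
        self.clip_l.reset_clip_layer()


clip = DualClipStub()
clip.clip_layer(-2)
clip.reset_clip_layer()
assert clip.clip_l.state == clip.clip_g.state == "default"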