Project: chenpangpang/ComfyUI

Commit c2cb8e88, authored Feb 25, 2024 by comfyanonymous
Parent: 1cb3f6a8

    Always return unprojected pooled output for gligen.

Showing 6 changed files with 38 additions and 28 deletions:

    comfy/clip_model.py    +1   -1
    comfy/lora.py          +2   -2
    comfy/sd.py            +6   -3
    comfy/sd1_clip.py      +22  -15
    comfy/sdxl_clip.py     +6   -6
    nodes.py               +1   -1
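Taken together, the commit replaces the per-layer CLIP setters with an options-based API and threads a new "projected_pooled" flag through the text-encoder stack so the GLIGEN node can request the pooled output taken before text_projection. A compressed before/after summary of the call sites touched in the diffs below (a sketch, not code copied from any single file):

# Old API (removed in this commit):
clip.cond_stage_model.clip_layer(layer_idx)
clip.cond_stage_model.reset_clip_layer()

# New API (added in this commit):
clip.cond_stage_model.set_clip_options({"layer": layer_idx})
clip.cond_stage_model.set_clip_options({"projected_pooled": False})
clip.cond_stage_model.reset_clip_options()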
comfy/clip_model.py

@@ -133,7 +133,7 @@ class CLIPTextModel(torch.nn.Module):
     def forward(self, *args, **kwargs):
         x = self.text_model(*args, **kwargs)
         out = self.text_projection(x[2])
-        return (x[0], x[1], out)
+        return (x[0], x[1], out, x[2])
 
 class CLIPVisionEmbeddings(torch.nn.Module):
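CLIPTextModel.forward now returns a 4-tuple whose last element is the pooled output before text_projection is applied (out = self.text_projection(x[2]), so x[2] is the unprojected pooled tensor). A minimal unpacking sketch with illustrative names that are not from the repository; clip_text_model and its input are assumed to be set up elsewhere:

# The first two elements pass through unchanged from self.text_model; the last
# two are the projected and unprojected pooled outputs respectively.
x0, x1, pooled_projected, pooled_unprojected = clip_text_model(input_tokens)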
comfy/lora.py

@@ -201,9 +201,9 @@ def model_lora_keys_clip(model, key_map={}):
             key_map[lora_key] = k
 
-    k = "clip_g.text_projection"
+    k = "clip_g.transformer.text_projection.weight"
     if k in sdk:
-        key_map["lora_prior_te_text_projection"] = k #cascade lora
+        key_map["lora_prior_te_text_projection"] = k #cascade lora?
         #key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
         #key_map["lora_te_text_projection"] = k
comfy/sd.py

@@ -123,10 +123,13 @@ class CLIP:
         return self.tokenizer.tokenize_with_weights(text, return_word_ids)
 
     def encode_from_tokens(self, tokens, return_pooled=False):
+        self.cond_stage_model.reset_clip_options()
+
         if self.layer_idx is not None:
-            self.cond_stage_model.clip_layer(self.layer_idx)
-        else:
-            self.cond_stage_model.reset_clip_layer()
+            self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
+
+        if return_pooled == "unprojected":
+            self.cond_stage_model.set_clip_options({"projected_pooled": False})
 
         self.load_model()
         cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
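The public entry point is CLIP.encode_from_tokens: passing return_pooled="unprojected" now resets the clip options and sets {"projected_pooled": False} before encoding. A short usage sketch, assuming clip is an already loaded comfy.sd.CLIP object and the prompt is arbitrary:

tokens = clip.tokenize("a photo of a cat")

# Previous behaviour, still available: pooled output after text_projection.
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)

# New in this commit: pooled output taken before text_projection.
cond, pooled_unprojected = clip.encode_from_tokens(tokens, return_pooled="unprojected")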
comfy/sd1_clip.py

@@ -91,11 +91,13 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.enable_attention_masks = enable_attention_masks
 
         self.layer_norm_hidden_state = layer_norm_hidden_state
+        self.return_projected_pooled = True
+
         if layer == "hidden":
             assert layer_idx is not None
             assert abs(layer_idx) < self.num_layers
-            self.clip_layer(layer_idx)
-        self.layer_default = (self.layer, self.layer_idx)
+            self.set_clip_options({"layer": layer_idx})
+        self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)
 
     def freeze(self):
         self.transformer = self.transformer.eval()

@@ -103,16 +105,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         for param in self.parameters():
             param.requires_grad = False
 
-    def clip_layer(self, layer_idx):
-        if abs(layer_idx) > self.num_layers:
+    def set_clip_options(self, options):
+        layer_idx = options.get("layer", self.layer_idx)
+        self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
+        if layer_idx is None or abs(layer_idx) > self.num_layers:
             self.layer = "last"
         else:
             self.layer = "hidden"
             self.layer_idx = layer_idx
 
-    def reset_clip_layer(self):
-        self.layer = self.layer_default[0]
-        self.layer_idx = self.layer_default[1]
+    def reset_clip_options(self):
+        self.layer = self.options_default[0]
+        self.layer_idx = self.options_default[1]
+        self.return_projected_pooled = self.options_default[2]
 
     def set_up_textual_embeddings(self, tokens, current_embeds):
         out_tokens = []

@@ -177,10 +182,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         else:
             z = outputs[1]
 
-        if outputs[2] is not None:
-            pooled_output = outputs[2].float()
-        else:
-            pooled_output = None
+        pooled_output = None
+        if len(outputs) >= 3:
+            if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:
+                pooled_output = outputs[3].float()
+            elif outputs[2] is not None:
+                pooled_output = outputs[2].float()
 
         return z.float(), pooled_output

@@ -497,11 +504,11 @@ class SD1ClipModel(torch.nn.Module):
         self.clip = "clip_{}".format(self.clip_name)
         setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
 
-    def clip_layer(self, layer_idx):
-        getattr(self, self.clip).clip_layer(layer_idx)
+    def set_clip_options(self, options):
+        getattr(self, self.clip).set_clip_options(options)
 
-    def reset_clip_layer(self):
-        getattr(self, self.clip).reset_clip_layer()
+    def reset_clip_options(self):
+        getattr(self, self.clip).reset_clip_options()
 
     def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs = token_weight_pairs[self.clip_name]
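The core of the change is the options dict on SDClipModel: set_clip_options only overrides the keys it is given, reset_clip_options restores the constructor defaults captured in self.options_default, and the encode path prefers outputs[3] (the unprojected pooled output) whenever return_projected_pooled is False and the transformer actually returned a fourth element. A minimal sketch of driving the options directly, assuming model is an already constructed SDClipModel:

# Select a hidden layer and request the unprojected pooled output; keys may be
# passed separately or combined, and missing keys keep their current values.
model.set_clip_options({"layer": -2})
model.set_clip_options({"projected_pooled": False})

# Restore (layer, layer_idx, return_projected_pooled) to the constructor defaults.
model.reset_clip_options()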
comfy/sdxl_clip.py

@@ -40,13 +40,13 @@ class SDXLClipModel(torch.nn.Module):
         self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
         self.clip_g = SDXLClipG(device=device, dtype=dtype)
 
-    def clip_layer(self, layer_idx):
-        self.clip_l.clip_layer(layer_idx)
-        self.clip_g.clip_layer(layer_idx)
+    def set_clip_options(self, options):
+        self.clip_l.set_clip_options(options)
+        self.clip_g.set_clip_options(options)
 
-    def reset_clip_layer(self):
-        self.clip_g.reset_clip_layer()
-        self.clip_l.reset_clip_layer()
+    def reset_clip_options(self):
+        self.clip_g.reset_clip_options()
+        self.clip_l.reset_clip_options()
 
     def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs_g = token_weight_pairs["g"]
nodes.py

@@ -1003,7 +1003,7 @@ class GLIGENTextBoxApply:
     def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y):
         c = []
-        cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True)
+        cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled="unprojected")
         for t in conditioning_to:
             n = [t[0], t[1].copy()]
             position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]
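With the one-line change above, GLIGENTextBoxApply always conditions the gligen position parameters on the unprojected pooled output, matching the commit message. A hedged sketch of calling the node function directly from Python; ComfyUI node outputs are assumed to be returned as tuples, and conditioning, clip, and gligen_model are assumed to be loaded elsewhere:

node = GLIGENTextBoxApply()
(conditioning_with_box,) = node.append(conditioning, clip, gligen_model,
                                       text="a red ball", width=256, height=256, x=0, y=0)
# Each appended entry's position_params now carries cond_pooled computed with
# return_pooled="unprojected", i.e. the pooled output before text_projection.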