chenpangpang / ComfyUI
"...git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "1eaae43a14ab00cf518ae664d6d0021441cb3014"
Commit e44fa566, authored Jul 10, 2024 by comfyanonymous
Support returning text encoder attention masks.
Parent: 90389b3b
Showing 1 changed file with 14 additions and 5 deletions.
comfy/sd1_clip.py
@@ -38,7 +38,9 @@ class ClipTokenWeightEncoder:
         if has_weights or sections == 0:
             to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
 
-        out, pooled = self.encode(to_encode)
+        o = self.encode(to_encode)
+        out, pooled = o[:2]
+
         if pooled is not None:
             first_pooled = pooled[0:1].to(model_management.intermediate_device())
         else:
@@ -57,8 +59,11 @@ class ClipTokenWeightEncoder:
             output.append(z)
 
         if (len(output) == 0):
-            return out[-1:].to(model_management.intermediate_device()), first_pooled
-        return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
+            r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
+        else:
+            r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
+        r = r + tuple(map(lambda a: a[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device()), o[2:]))
+        return r
 
 class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
@@ -70,7 +75,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
                  freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
                  special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
-                 return_projected_pooled=True):  # clip-vit-base-patch32
+                 return_projected_pooled=True, return_attention_masks=False):  # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
@@ -96,6 +101,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.layer_norm_hidden_state = layer_norm_hidden_state
         self.return_projected_pooled = return_projected_pooled
+        self.return_attention_masks = return_attention_masks
 
         if layer == "hidden":
             assert layer_idx is not None
@@ -169,7 +175,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         tokens = torch.LongTensor(tokens).to(device)
 
         attention_mask = None
-        if self.enable_attention_masks or self.zero_out_masked:
+        if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
             attention_mask = torch.zeros_like(tokens)
             end_token = self.special_tokens.get("end", -1)
             for x in range(attention_mask.shape[0]):
@@ -200,6 +206,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         elif outputs[2] is not None:
             pooled_output = outputs[2].float()
 
+        if self.return_attention_masks:
+            return z, pooled_output, attention_mask
+
         return z, pooled_output
 
     def encode(self, tokens):
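
The ClipTokenWeightEncoder change lets encode() return more than the usual (cond, pooled) pair: anything after the first two elements, such as the new attention mask, is sliced to the encoded sections, flattened into a single row, and appended to the return tuple. Below is a minimal, self-contained sketch of that pass-through using dummy tensors and a stand-in encode function; encode_stub and all shapes are illustrative assumptions, not ComfyUI code.

import torch

def encode_stub(batch):
    # Stand-in for the text encoder with return_attention_masks=True:
    # returns (hidden states, pooled output, attention mask).
    cond = torch.randn(batch, 77, 768)
    pooled = torch.randn(batch, 768)
    attention_mask = torch.ones(batch, 77, dtype=torch.long)
    return cond, pooled, attention_mask

sections = 2                   # number of prompt chunks being encoded
o = encode_stub(sections + 1)  # one extra batch entry for the empty-token reference
out, pooled = o[:2]

r = (torch.cat([out[k:k + 1] for k in range(sections)], dim=-2), pooled[0:1])
# Extra outputs (here the attention mask) are trimmed to the real sections,
# flattened into one row, and appended to the return tuple.
r = r + tuple(a[:sections].flatten().unsqueeze(dim=0) for a in o[2:])

cond, first_pooled, mask = r
print(cond.shape, first_pooled.shape, mask.shape)  # (1, 154, 768) (1, 768) (1, 154)

Callers that previously unpacked two values keep working; only code that opts in to the attention masks (or inspects the tuple length) sees the extra entry.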
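In forward, the mask is now also built when return_attention_masks is set (previously only for enable_attention_masks or zero_out_masked) and is returned as a third value next to z and pooled_output. The loop that fills the mask is elided in the diff; the sketch below is an assumed reconstruction in which positions up to and including the "end" token are marked 1 and padding stays 0, using toy word ids.

import torch

tokens = torch.tensor([[49406, 320, 1929, 49407, 49407, 49407]])  # start, two words, end, padding
end_token = 49407

attention_mask = torch.zeros_like(tokens)
for x in range(attention_mask.shape[0]):
    for y in range(attention_mask.shape[1]):
        attention_mask[x, y] = 1
        if tokens[x, y] == end_token:
            break

print(attention_mask)  # tensor([[1, 1, 1, 1, 0, 0]])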