ComfyUI commit 742d5720, authored Jun 09, 2024 by comfyanonymous

Support zeroing out text embeddings with the attention mask.
Parent: 6cd8ffc4

Showing 1 changed file with 9 additions and 4 deletions:

comfy/sd1_clip.py (+9, -4)

--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -68,7 +68,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     ]
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
                  freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
-                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, return_projected_pooled=True):  # clip-vit-base-patch32
+                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
+                 return_projected_pooled=True):  # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
@@ -90,6 +91,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.enable_attention_masks = enable_attention_masks
+        self.zero_out_masked = zero_out_masked
 
         self.layer_norm_hidden_state = layer_norm_hidden_state
         self.return_projected_pooled = return_projected_pooled
@@ -179,9 +181,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.transformer.set_input_embeddings(backup_embeds)
 
         if self.layer == "last":
-            z = outputs[0]
+            z = outputs[0].float()
         else:
-            z = outputs[1]
+            z = outputs[1].float()
 
+        if self.zero_out_masked and attention_mask is not None:
+            z *= attention_mask.unsqueeze(-1).float()
+
         pooled_output = None
         if len(outputs) >= 3:
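
The two added lines above are the core of the commit. As a minimal standalone sketch of what they do (the tensor values and shapes here are illustrative, not taken from the repository):

import torch

# Illustrative example: 2 prompts, 5 token positions, 4-dim embeddings.
z = torch.randn(2, 5, 4)                          # text-encoder output
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 0]])  # 1 = real token, 0 = padding

# Same operation as the added lines: unsqueeze(-1) reshapes the mask to
# [batch, seq_len, 1] so it broadcasts over the embedding dimension,
# zeroing out every embedding at a masked (padding) position.
z *= attention_mask.unsqueeze(-1).float()

assert torch.all(z[:, -1] == 0)  # the last position is padding in both prompts
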
@@ -190,7 +195,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         elif outputs[2] is not None:
             pooled_output = outputs[2].float()
 
-        return z.float(), pooled_output
+        return z, pooled_output
 
     def encode(self, tokens):
         return self(tokens)
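
Two notes, as read from the diff: the zeroing only runs when attention_mask is not None, so in this version the new flag presumably takes effect only together with enable_attention_masks=True (e.g. SDClipModel(..., enable_attention_masks=True, zero_out_masked=True)); and the .float() casts moved from the return statement up to where z is assigned, so the mask multiply happens in float32 while the returned dtype is unchanged.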