Commit 44361f63, authored Sep 15, 2023 by comfyanonymous

Support for text encoder models that need attention_mask.

Parent: 0d8f3764
1 changed file with 12 additions and 1 deletion:

comfy/sd1_clip.py (+12, -1)
comfy/sd1_clip.py @ 44361f63

@@ -71,6 +71,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.empty_tokens = [[49406] + [49407] * 76]
         self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
+        self.enable_attention_masks = False
         self.layer_norm_hidden_state = True
         if layer == "hidden":
@@ -147,7 +148,17 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             precision_scope = lambda a, b: contextlib.nullcontext(a)

         with precision_scope(model_management.get_autocast_device(device), torch.float32):
-            outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+            attention_mask = None
+            if self.enable_attention_masks:
+                attention_mask = torch.zeros_like(tokens)
+                max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
+                for x in range(attention_mask.shape[0]):
+                    for y in range(attention_mask.shape[1]):
+                        attention_mask[x, y] = 1
+                        if tokens[x, y] == max_token:
+                            break
+
+            outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, output_hidden_states=self.layer=="hidden")
             self.transformer.set_input_embeddings(backup_embeds)
             if self.layer == "last":
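The mask construction in the second hunk is easy to verify in isolation. Below is a minimal standalone sketch (not part of the commit) that runs the same loop on a hand-built token batch; the ids follow SD1.x CLIP conventions (49406 = start token, 49407 = end/padding token, vocabulary size 49408), and the two word-token ids are illustrative:

import torch

# Dummy batch: start token, two word tokens, end token, then padding.
tokens = torch.tensor([[49406, 320, 1929, 49407, 49407, 49407]])
max_token = 49408 - 1  # index of the last embedding row, i.e. the end token

# Same logic as the commit: mark every position up to and
# including the first end token, leave the rest masked out.
attention_mask = torch.zeros_like(tokens)
for x in range(attention_mask.shape[0]):
    for y in range(attention_mask.shape[1]):
        attention_mask[x, y] = 1
        if tokens[x, y] == max_token:
            break

print(attention_mask)  # tensor([[1, 1, 1, 1, 0, 0]])

Because enable_attention_masks defaults to False, existing models are unaffected; a text encoder that needs the mask opts in by setting the attribute to True before encoding.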