Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
6c6a3925
Commit
6c6a3925
authored
Apr 02, 2024
by
comfyanonymous
Browse files
Fix saving text encoder in fp8.
parent
e6482fbb
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
2 deletions
+17
-2
comfy/diffusers_convert.py
comfy/diffusers_convert.py
+17
-2
No files found.
comfy/diffusers_convert.py
View file @
6c6a3925
...
...
@@ -206,6 +206,21 @@ textenc_pattern = re.compile("|".join(protected.keys()))
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
code2idx
=
{
"q"
:
0
,
"k"
:
1
,
"v"
:
2
}
# This function exists because at the time of writing torch.cat can't do fp8 with cuda
def
cat_tensors
(
tensors
):
x
=
0
for
t
in
tensors
:
x
+=
t
.
shape
[
0
]
shape
=
[
x
]
+
list
(
tensors
[
0
].
shape
)[
1
:]
out
=
torch
.
empty
(
shape
,
device
=
tensors
[
0
].
device
,
dtype
=
tensors
[
0
].
dtype
)
x
=
0
for
t
in
tensors
:
out
[
x
:
x
+
t
.
shape
[
0
]]
=
t
x
+=
t
.
shape
[
0
]
return
out
def
convert_text_enc_state_dict_v20
(
text_enc_dict
,
prefix
=
""
):
new_state_dict
=
{}
...
...
@@ -249,13 +264,13 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
if
None
in
tensors
:
raise
Exception
(
"CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing"
)
relabelled_key
=
textenc_pattern
.
sub
(
lambda
m
:
protected
[
re
.
escape
(
m
.
group
(
0
))],
k_pre
)
new_state_dict
[
relabelled_key
+
".in_proj_weight"
]
=
torch
.
cat
(
tensors
)
new_state_dict
[
relabelled_key
+
".in_proj_weight"
]
=
cat_tensors
(
tensors
)
for
k_pre
,
tensors
in
capture_qkv_bias
.
items
():
if
None
in
tensors
:
raise
Exception
(
"CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing"
)
relabelled_key
=
textenc_pattern
.
sub
(
lambda
m
:
protected
[
re
.
escape
(
m
.
group
(
0
))],
k_pre
)
new_state_dict
[
relabelled_key
+
".in_proj_bias"
]
=
torch
.
cat
(
tensors
)
new_state_dict
[
relabelled_key
+
".in_proj_bias"
]
=
cat_tensors
(
tensors
)
return
new_state_dict
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment