chenpangpang / ComfyUI · Commits

Commit 04d9bc13
authored Apr 14, 2023 by comfyanonymous

Safely load pickled embeds that don't load with weights_only=True.

parent 334aab05

Showing 1 changed file with 34 additions and 6 deletions.

comfy/sd1_clip.py  +34  -6
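Context for the change: torch.load(..., weights_only=True) uses a restricted unpickler that only admits tensors and a small allowlist of types, and raises on anything else. Some older textual-inversion embeddings were pickled with objects that trip this check, so the commit adds a fallback that pulls the raw tensor bytes straight out of the checkpoint's zip container instead of unpickling it. A minimal sketch of the failure mode being guarded against (the file name is hypothetical, not from the commit):

    import torch

    try:
        embed = torch.load("embedding.pt", weights_only=True, map_location="cpu")
    except Exception as e:
        # Embeds pickled with non-allowlisted globals land here; the commit's
        # safe_load_embed_zip() then recovers the tensor data without unpickling.
        print("weights_only load failed:", e)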
...
@@ -3,6 +3,7 @@ import os
 from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
 import torch
 import traceback
+import zipfile

 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
...
@@ -171,6 +172,26 @@ def unescape_important(text):
     text = text.replace("\0\2", "(")
     return text

+def safe_load_embed_zip(embed_path):
+    with zipfile.ZipFile(embed_path) as myzip:
+        names = list(filter(lambda a: "data/" in a, myzip.namelist()))
+        names.reverse()
+        for n in names:
+            with myzip.open(n) as myfile:
+                data = myfile.read()
+                number = len(data) // 4
+                length_embed = 1024 #sd2.x
+                if number < 768:
+                    continue
+                if number % 768 == 0:
+                    length_embed = 768 #sd1.x
+                num_embeds = number // length_embed
+                embed = torch.frombuffer(data, dtype=torch.float)
+                out = embed.reshape((num_embeds, length_embed)).clone()
+                del embed
+                return out
+
+
 def load_embed(embedding_name, embedding_directory):
     if isinstance(embedding_directory, str):
         embedding_directory = [embedding_directory]
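A note on safe_load_embed_zip: a PyTorch .pt checkpoint is a zip archive in which the pickle stream lives in data.pkl while each tensor's raw little-endian float32 storage sits in its own data/N entry. Reading those entries with zipfile and torch.frombuffer therefore recovers the embedding without running the pickle machinery at all. The width heuristic assumes 768 floats per vector for SD1.x embeds and 1024 for SD2.x. A sketch of the shape arithmetic, with a made-up byte count for illustration:

    data_len = 6144                  # hypothetical size of one data/N entry, in bytes
    number = data_len // 4           # 1536 float32 values
    length_embed = 1024              # assume an SD2.x width...
    if number % 768 == 0:
        length_embed = 768           # ...unless it divides evenly as SD1.x
    num_embeds = number // length_embed
    print(num_embeds, length_embed)  # -> 2 768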
...
@@ -195,13 +216,18 @@ def load_embed(embedding_name, embedding_directory):
     embed_path = valid_file

+    embed_out = None
+
     try:
         if embed_path.lower().endswith(".safetensors"):
             import safetensors.torch
             embed = safetensors.torch.load_file(embed_path, device="cpu")
         else:
             if 'weights_only' in torch.load.__code__.co_varnames:
-                embed = torch.load(embed_path, weights_only=True, map_location="cpu")
+                try:
+                    embed = torch.load(embed_path, weights_only=True, map_location="cpu")
+                except:
+                    embed_out = safe_load_embed_zip(embed_path)
             else:
                 embed = torch.load(embed_path, map_location="cpu")

     except Exception as e:
...
@@ -210,11 +236,13 @@ def load_embed(embedding_name, embedding_directory):
         print("error loading embedding, skipping loading:", embedding_name)
         return None

-    if 'string_to_param' in embed:
-        values = embed['string_to_param'].values()
-    else:
-        values = embed.values()
-    return next(iter(values))
+    if embed_out is None:
+        if 'string_to_param' in embed:
+            values = embed['string_to_param'].values()
+        else:
+            values = embed.values()
+        embed_out = next(iter(values))
+    return embed_out

 class SD1Tokenizer:
     def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None):
...
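With this refactor, embed_out carries the result out of both paths: when the zip fallback already produced a tensor, the pickle-derived branch is skipped; otherwise the first tensor is pulled from either embed['string_to_param'] (the original textual-inversion layout) or the top level of the dict. A sketch of the two dict layouts that final branch handles (keys other than 'string_to_param' are illustrative):

    import torch

    ti_style = {'string_to_param': {'*': torch.zeros(2, 768)}}  # textual-inversion layout
    flat_style = {'emb_params': torch.zeros(2, 768)}            # plain dict of tensors

    for embed in (ti_style, flat_style):
        values = embed['string_to_param'].values() if 'string_to_param' in embed else embed.values()
        print(next(iter(values)).shape)  # torch.Size([2, 768])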