chenpangpang / ComfyUI / Commits

Commit b3fcd64c
authored Nov 06, 2023 by comfyanonymous

Make SDTokenizer class work with more types of tokenizers.

parent a6c83b3c
Changes: 1

Showing 1 changed file with 24 additions and 10 deletions.

comfy/sd1_clip.py  +24  -10
comfy/sd1_clip.py

@@ -343,17 +343,24 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
     return embed_out
 
 class SDTokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
-        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
+        self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
         self.max_length = max_length
-        self.max_tokens_per_section = self.max_length - 2
 
         empty = self.tokenizer('')["input_ids"]
-        self.start_token = empty[0]
-        self.end_token = empty[1]
+        if has_start_token:
+            self.tokens_start = 1
+            self.start_token = empty[0]
+            self.end_token = empty[1]
+        else:
+            self.tokens_start = 0
+            self.start_token = None
+            self.end_token = empty[0]
         self.pad_with_end = pad_with_end
+        self.pad_to_max_length = pad_to_max_length
+
         vocab = self.tokenizer.get_vocab()
         self.inv_vocab = {v: k for k, v in vocab.items()}
         self.embedding_directory = embedding_directory

@@ -414,11 +421,13 @@ class SDTokenizer:
                     else:
                         continue
                 #parse word
-                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]])
+                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
 
         #reshape token array to CLIP input size
         batched_tokens = []
-        batch = [(self.start_token, 1.0, 0)]
+        batch = []
+        if self.start_token is not None:
+            batch.append((self.start_token, 1.0, 0))
         batched_tokens.append(batch)
         for i, t_group in enumerate(tokens):
             #determine if we're going to try and keep the tokens in a single batch

@@ -435,16 +444,21 @@ class SDTokenizer:
                     #add end token and pad
                     else:
                         batch.append((self.end_token, 1.0, 0))
-                        batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
+                        if self.pad_to_max_length:
+                            batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
                     #start new batch
-                    batch = [(self.start_token, 1.0, 0)]
+                    batch = []
+                    if self.start_token is not None:
+                        batch.append((self.start_token, 1.0, 0))
                     batched_tokens.append(batch)
                 else:
                     batch.extend([(t, w, i + 1) for t, w in t_group])
                     t_group = []
 
         #fill last batch
-        batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))
+        batch.append((self.end_token, 1.0, 0))
+        if self.pad_to_max_length:
+            batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
 
         if not return_word_ids:
             batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
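Note (illustration only, not part of the commit): with the new tokenizer_class, has_start_token, and pad_to_max_length parameters, SDTokenizer is no longer hard-wired to CLIPTokenizer's start-token/end-token layout, and the defaults (CLIPTokenizer, has_start_token=True, pad_to_max_length=True) keep the previous behavior. The sketch below shows how a subclass might wire in a T5-style tokenizer, which emits no start token when encoding the empty string; the tokenizer path, embedding_size, and embedding_key values are placeholders, not values taken from this commit.

from transformers import T5TokenizerFast
from comfy.sd1_clip import SDTokenizer

class T5StyleTokenizer(SDTokenizer):
    # Hypothetical subclass: T5 has no start/BOS token, so tokens_start
    # resolves to 0 and end_token is taken from empty[0]; padding every
    # batch out to max_length is disabled via pad_to_max_length=False.
    def __init__(self, tokenizer_path="path/to/t5_tokenizer", embedding_directory=None):
        super().__init__(
            tokenizer_path,
            max_length=77,
            pad_with_end=False,              # pad with token id 0 instead of the end token
            embedding_directory=embedding_directory,
            embedding_size=4096,             # placeholder, depends on the text encoder
            embedding_key='t5',              # placeholder embedding key
            tokenizer_class=T5TokenizerFast,
            has_start_token=False,           # '' encodes to [end_token] only
            pad_to_max_length=False,         # keep batches unpadded
        )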