Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
e60ca692
Commit
e60ca692
authored
Oct 27, 2023
by
comfyanonymous
Browse files
SD1 and SD2 clip and tokenizer code is now more similar to the SDXL one.
parent
6ec3f12c
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
69 additions
and
30 deletions
+69
-30
comfy/lora.py
comfy/lora.py
+4
-2
comfy/sd1_clip.py
comfy/sd1_clip.py
+39
-2
comfy/sd2_clip.py
comfy/sd2_clip.py
+10
-2
comfy/sdxl_clip.py
comfy/sdxl_clip.py
+7
-22
comfy/supported_models.py
comfy/supported_models.py
+9
-2
No files found.
comfy/lora.py
View file @
e60ca692
...
...
@@ -141,9 +141,9 @@ def model_lora_keys_clip(model, key_map={}):
text_model_lora_key
=
"lora_te_text_model_encoder_layers_{}_{}"
clip_l_present
=
False
for
b
in
range
(
32
):
for
b
in
range
(
32
):
#TODO: clean up
for
c
in
LORA_CLIP_MAP
:
k
=
"transformer.text_model.encoder.layers.{}.{}.weight"
.
format
(
b
,
c
)
k
=
"
clip_h.
transformer.text_model.encoder.layers.{}.{}.weight"
.
format
(
b
,
c
)
if
k
in
sdk
:
lora_key
=
text_model_lora_key
.
format
(
b
,
LORA_CLIP_MAP
[
c
])
key_map
[
lora_key
]
=
k
...
...
@@ -154,6 +154,8 @@ def model_lora_keys_clip(model, key_map={}):
k
=
"clip_l.transformer.text_model.encoder.layers.{}.{}.weight"
.
format
(
b
,
c
)
if
k
in
sdk
:
lora_key
=
text_model_lora_key
.
format
(
b
,
LORA_CLIP_MAP
[
c
])
key_map
[
lora_key
]
=
k
lora_key
=
"lora_te1_text_model_encoder_layers_{}_{}"
.
format
(
b
,
LORA_CLIP_MAP
[
c
])
#SDXL base
key_map
[
lora_key
]
=
k
clip_l_present
=
True
...
...
comfy/sd1_clip.py
View file @
e60ca692
...
...
@@ -35,7 +35,7 @@ class ClipTokenWeightEncoder:
return
z_empty
.
cpu
(),
first_pooled
.
cpu
()
return
torch
.
cat
(
output
,
dim
=-
2
).
cpu
(),
first_pooled
.
cpu
()
class
SD
1
ClipModel
(
torch
.
nn
.
Module
,
ClipTokenWeightEncoder
):
class
SDClipModel
(
torch
.
nn
.
Module
,
ClipTokenWeightEncoder
):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS
=
[
"last"
,
...
...
@@ -342,7 +342,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
embed_out
=
next
(
iter
(
values
))
return
embed_out
class
SD
1
Tokenizer
:
class
SDTokenizer
:
def
__init__
(
self
,
tokenizer_path
=
None
,
max_length
=
77
,
pad_with_end
=
True
,
embedding_directory
=
None
,
embedding_size
=
768
,
embedding_key
=
'clip_l'
):
if
tokenizer_path
is
None
:
tokenizer_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)),
"sd1_tokenizer"
)
...
...
@@ -454,3 +454,40 @@ class SD1Tokenizer:
def
untokenize
(
self
,
token_weight_pair
):
return
list
(
map
(
lambda
a
:
(
a
,
self
.
inv_vocab
[
a
[
0
]]),
token_weight_pair
))
class SD1Tokenizer:
    """Expose one CLIP tokenizer through a name-keyed interface.

    The wrapped tokenizer is stored on the instance under the attribute
    ``clip_<clip_name>`` (for example ``clip_l``), and tokenization results
    are returned in a dict keyed by ``clip_name``.
    """

    def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer):
        """Build the wrapped tokenizer.

        Args:
            embedding_directory: Forwarded as a keyword argument to
                ``tokenizer``.
            clip_name: Short name used both as the result-dict key and as
                the suffix of the attribute holding the tokenizer.
            tokenizer: Callable (by default ``SDTokenizer``) invoked with
                ``embedding_directory=...`` to create the tokenizer.
        """
        self.clip_name = clip_name
        self.clip = f"clip_{self.clip_name}"
        wrapped = tokenizer(embedding_directory=embedding_directory)
        setattr(self, self.clip, wrapped)

    def tokenize_with_weights(self, text: str, return_word_ids=False):
        """Tokenize ``text`` and return ``{clip_name: tokens}``."""
        wrapped = getattr(self, self.clip)
        tokens = wrapped.tokenize_with_weights(text, return_word_ids)
        return {self.clip_name: tokens}

    def untokenize(self, token_weight_pair):
        """Delegate to the wrapped tokenizer's ``untokenize``."""
        wrapped = getattr(self, self.clip)
        return wrapped.untokenize(token_weight_pair)
class SD1ClipModel(torch.nn.Module):
    """Hold one CLIP text encoder under a name-derived attribute.

    The encoder is stored as ``clip_<clip_name>`` (for example ``clip_l``)
    and every public method forwards to it. ``encode_token_weights`` expects
    a dict and picks the entry keyed by ``clip_name``.
    """

    def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel):
        """Create the wrapped encoder.

        Args:
            device: Forwarded to ``clip_model`` as ``device``.
            dtype: Forwarded to ``clip_model`` as ``dtype``.
            clip_name: Short name used as the token-dict key and as the
                suffix of the attribute holding the encoder.
            clip_model: Callable (by default ``SDClipModel``) invoked with
                ``device`` and ``dtype`` to build the encoder.
        """
        super().__init__()
        self.clip_name = clip_name
        self.clip = f"clip_{self.clip_name}"
        encoder = clip_model(device=device, dtype=dtype)
        # Plain attribute assignment so the encoder is reachable via
        # getattr(self, self.clip) under its name-derived attribute.
        setattr(self, self.clip, encoder)

    def clip_layer(self, layer_idx):
        """Forward ``layer_idx`` to the wrapped encoder's ``clip_layer``."""
        encoder = getattr(self, self.clip)
        encoder.clip_layer(layer_idx)

    def reset_clip_layer(self):
        """Forward to the wrapped encoder's ``reset_clip_layer``."""
        encoder = getattr(self, self.clip)
        encoder.reset_clip_layer()

    def encode_token_weights(self, token_weight_pairs):
        """Encode the ``clip_name`` entry of ``token_weight_pairs``.

        Returns:
            The two values produced by the wrapped encoder's
            ``encode_token_weights``, as a 2-tuple.
        """
        selected_pairs = token_weight_pairs[self.clip_name]
        encoder = getattr(self, self.clip)
        encoded, pooled_output = encoder.encode_token_weights(selected_pairs)
        return encoded, pooled_output

    def load_sd(self, sd):
        """Load ``sd`` into the wrapped encoder and return its result."""
        encoder = getattr(self, self.clip)
        return encoder.load_sd(sd)
comfy/sd2_clip.py
View file @
e60ca692
...
...
@@ -2,7 +2,7 @@ from comfy import sd1_clip
import
torch
import
os
class
SD2ClipModel
(
sd1_clip
.
SD
1
ClipModel
):
class
SD2Clip
H
Model
(
sd1_clip
.
SDClipModel
):
def
__init__
(
self
,
arch
=
"ViT-H-14"
,
device
=
"cpu"
,
max_length
=
77
,
freeze
=
True
,
layer
=
"penultimate"
,
layer_idx
=
None
,
textmodel_path
=
None
,
dtype
=
None
):
if
layer
==
"penultimate"
:
layer
=
"hidden"
...
...
@@ -12,6 +12,14 @@ class SD2ClipModel(sd1_clip.SD1ClipModel):
super
().
__init__
(
device
=
device
,
freeze
=
freeze
,
layer
=
layer
,
layer_idx
=
layer_idx
,
textmodel_json_config
=
textmodel_json_config
,
textmodel_path
=
textmodel_path
,
dtype
=
dtype
)
self
.
empty_tokens
=
[[
49406
]
+
[
49407
]
+
[
0
]
*
75
]
class
SD2Tokenizer
(
sd1_clip
.
SD
1
Tokenizer
):
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
    """``SDTokenizer`` preset with ``pad_with_end=False`` and embedding size 1024."""

    def __init__(self, tokenizer_path=None, embedding_directory=None):
        """Initialise the base tokenizer with this class's fixed settings.

        Args:
            tokenizer_path: Passed positionally to the base initializer.
            embedding_directory: Passed through as a keyword argument.
        """
        super().__init__(
            tokenizer_path,
            pad_with_end=False,
            embedding_directory=embedding_directory,
            embedding_size=1024,
        )
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
    """``SD1Tokenizer`` configured with clip name ``"h"`` and ``SD2ClipHTokenizer``."""

    def __init__(self, embedding_directory=None):
        """Initialise the wrapper.

        Args:
            embedding_directory: Passed through to the base initializer.
        """
        super().__init__(
            embedding_directory=embedding_directory,
            clip_name="h",
            tokenizer=SD2ClipHTokenizer,
        )
class SD2ClipModel(sd1_clip.SD1ClipModel):
    """``SD1ClipModel`` configured with clip name ``"h"`` and ``SD2ClipHModel``."""

    def __init__(self, device="cpu", dtype=None):
        """Initialise the wrapper.

        Args:
            device: Passed through to the base initializer.
            dtype: Passed through to the base initializer.
        """
        super().__init__(
            device=device,
            dtype=dtype,
            clip_name="h",
            clip_model=SD2ClipHModel,
        )
comfy/sdxl_clip.py
View file @
e60ca692
...
...
@@ -2,7 +2,7 @@ from comfy import sd1_clip
import
torch
import
os
class
SDXLClipG
(
sd1_clip
.
SD
1
ClipModel
):
class
SDXLClipG
(
sd1_clip
.
SDClipModel
):
def
__init__
(
self
,
device
=
"cpu"
,
max_length
=
77
,
freeze
=
True
,
layer
=
"penultimate"
,
layer_idx
=
None
,
textmodel_path
=
None
,
dtype
=
None
):
if
layer
==
"penultimate"
:
layer
=
"hidden"
...
...
@@ -16,14 +16,14 @@ class SDXLClipG(sd1_clip.SD1ClipModel):
def
load_sd
(
self
,
sd
):
return
super
().
load_sd
(
sd
)
class
SDXLClipGTokenizer
(
sd1_clip
.
SD
1
Tokenizer
):
class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
    """``SDTokenizer`` preset: no end-padding, embedding size 1280, key ``'clip_g'``."""

    def __init__(self, tokenizer_path=None, embedding_directory=None):
        """Initialise the base tokenizer with this class's fixed settings.

        Args:
            tokenizer_path: Passed positionally to the base initializer.
            embedding_directory: Passed through as a keyword argument.
        """
        super().__init__(
            tokenizer_path,
            pad_with_end=False,
            embedding_directory=embedding_directory,
            embedding_size=1280,
            embedding_key='clip_g',
        )
class
SDXLTokenizer
(
sd1_clip
.
SD1Tokenizer
)
:
class
SDXLTokenizer
:
def
__init__
(
self
,
embedding_directory
=
None
):
self
.
clip_l
=
sd1_clip
.
SD
1
Tokenizer
(
embedding_directory
=
embedding_directory
)
self
.
clip_l
=
sd1_clip
.
SDTokenizer
(
embedding_directory
=
embedding_directory
)
self
.
clip_g
=
SDXLClipGTokenizer
(
embedding_directory
=
embedding_directory
)
def
tokenize_with_weights
(
self
,
text
:
str
,
return_word_ids
=
False
):
...
...
@@ -38,7 +38,7 @@ class SDXLTokenizer(sd1_clip.SD1Tokenizer):
class
SDXLClipModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
device
=
"cpu"
,
dtype
=
None
):
super
().
__init__
()
self
.
clip_l
=
sd1_clip
.
SD
1
ClipModel
(
layer
=
"hidden"
,
layer_idx
=
11
,
device
=
device
,
dtype
=
dtype
)
self
.
clip_l
=
sd1_clip
.
SDClipModel
(
layer
=
"hidden"
,
layer_idx
=
11
,
device
=
device
,
dtype
=
dtype
)
self
.
clip_l
.
layer_norm_hidden_state
=
False
self
.
clip_g
=
SDXLClipG
(
device
=
device
,
dtype
=
dtype
)
...
...
@@ -63,21 +63,6 @@ class SDXLClipModel(torch.nn.Module):
else
:
return
self
.
clip_l
.
load_sd
(
sd
)
class
SDXLRefinerClipModel
(
torch
.
nn
.
Module
):
class
SDXLRefinerClipModel
(
sd1_clip
.
SD1ClipModel
):
def
__init__
(
self
,
device
=
"cpu"
,
dtype
=
None
):
super
().
__init__
()
self
.
clip_g
=
SDXLClipG
(
device
=
device
,
dtype
=
dtype
)
def
clip_layer
(
self
,
layer_idx
):
self
.
clip_g
.
clip_layer
(
layer_idx
)
def
reset_clip_layer
(
self
):
self
.
clip_g
.
reset_clip_layer
()
def
encode_token_weights
(
self
,
token_weight_pairs
):
token_weight_pairs_g
=
token_weight_pairs
[
"g"
]
g_out
,
g_pooled
=
self
.
clip_g
.
encode_token_weights
(
token_weight_pairs_g
)
return
g_out
,
g_pooled
def
load_sd
(
self
,
sd
):
return
self
.
clip_g
.
load_sd
(
sd
)
super
().
__init__
(
device
=
device
,
dtype
=
dtype
,
clip_name
=
"g"
,
clip_model
=
SDXLClipG
)
comfy/supported_models.py
View file @
e60ca692
...
...
@@ -38,8 +38,15 @@ class SD15(supported_models_base.BASE):
if
ids
.
dtype
==
torch
.
float32
:
state_dict
[
'cond_stage_model.transformer.text_model.embeddings.position_ids'
]
=
ids
.
round
()
replace_prefix
=
{}
replace_prefix
[
"cond_stage_model."
]
=
"cond_stage_model.clip_l."
state_dict
=
utils
.
state_dict_prefix_replace
(
state_dict
,
replace_prefix
)
return
state_dict
def process_clip_state_dict_for_saving(self, state_dict):
    """Rename CLIP keys for saving.

    Passes ``state_dict`` to ``utils.state_dict_prefix_replace`` with the
    mapping ``"clip_l."`` -> ``"cond_stage_model."`` and returns its result.
    """
    return utils.state_dict_prefix_replace(
        state_dict,
        {"clip_l.": "cond_stage_model."},
    )
def clip_target(self):
    """Return a ``ClipTarget`` built from ``SD1Tokenizer`` and ``SD1ClipModel``."""
    tokenizer_cls = sd1_clip.SD1Tokenizer
    model_cls = sd1_clip.SD1ClipModel
    return supported_models_base.ClipTarget(tokenizer_cls, model_cls)
...
...
@@ -62,12 +69,12 @@ class SD20(supported_models_base.BASE):
return
model_base
.
ModelType
.
EPS
def
process_clip_state_dict
(
self
,
state_dict
):
state_dict
=
utils
.
transformers_convert
(
state_dict
,
"cond_stage_model.model."
,
"cond_stage_model.transformer.text_model."
,
24
)
state_dict
=
utils
.
transformers_convert
(
state_dict
,
"cond_stage_model.model."
,
"cond_stage_model.
clip_h.
transformer.text_model."
,
24
)
return
state_dict
def
process_clip_state_dict_for_saving
(
self
,
state_dict
):
replace_prefix
=
{}
replace_prefix
[
""
]
=
"cond_stage_model.model
.
"
replace_prefix
[
"
clip_h
"
]
=
"cond_stage_model.model"
state_dict
=
utils
.
state_dict_prefix_replace
(
state_dict
,
replace_prefix
)
state_dict
=
diffusers_convert
.
convert_text_enc_state_dict_v20
(
state_dict
)
return
state_dict
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment