chenpangpang / ComfyUI · Commits

Commit 56d802e1
Authored Feb 05, 2023 by comfyanonymous

Use transformers CLIP instead of open_clip for SD2.x

This should make things a bit cleaner.

Parent: bf9ccffb
Showing 3 changed files with 77 additions and 103 deletions (+77 -103):

    comfy/sd.py                   +44  -29
    comfy/sd2_clip.py             +10  -74
    comfy/sd2_clip_config.json    +23  -0

comfy/sd.py (+44 -29)

@@ -40,6 +40,42 @@ def load_model_from_config(config, ckpt, verbose=False, load_state_dict_to=[]):
         if ids.dtype == torch.float32:
             sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
 
+    keys_to_replace = {
+        "cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
+        "cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
+        "cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight",
+        "cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias",
+    }
+
+    for x in keys_to_replace:
+        if x in sd:
+            sd[keys_to_replace[x]] = sd.pop(x)
+
+    resblock_to_replace = {
+        "ln_1": "layer_norm1",
+        "ln_2": "layer_norm2",
+        "mlp.c_fc": "mlp.fc1",
+        "mlp.c_proj": "mlp.fc2",
+        "attn.out_proj": "self_attn.out_proj",
+    }
+
+    for resblock in range(24):
+        for x in resblock_to_replace:
+            for y in ["weight", "bias"]:
+                k = "cond_stage_model.model.transformer.resblocks.{}.{}.{}".format(resblock, x, y)
+                k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, resblock_to_replace[x], y)
+                if k in sd:
+                    sd[k_to] = sd.pop(k)
+
+        for y in ["weight", "bias"]:
+            k_from = "cond_stage_model.model.transformer.resblocks.{}.attn.in_proj_{}".format(resblock, y)
+            if k_from in sd:
+                weights = sd.pop(k_from)
+                for x in range(3):
+                    p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
+                    k_to = "cond_stage_model.transformer.text_model.encoder.layers.{}.{}.{}".format(resblock, p[x], y)
+                    sd[k_to] = weights[1024*x:1024*(x + 1)]
+
     for x in load_state_dict_to:
         x.load_state_dict(sd, strict=False)
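
Note on the hunk above: open_clip keeps each attention block's query/key/value projection stacked in a single in_proj tensor, while the transformers CLIPTextModel expects separate q_proj/k_proj/v_proj parameters, so the converter slices the stacked tensor into 1024-row chunks (1024 is the ViT-H text width). A minimal standalone sketch of that slicing, using a made-up one-layer state dict purely for illustration:

import torch

# Hypothetical open_clip-style entry for resblock 0: q, k and v weights are
# stacked into one in_proj_weight of shape (3 * 1024, 1024).
sd = {"cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight": torch.randn(3 * 1024, 1024)}

# Same slicing as the diff: rows [0:1024) -> q_proj, [1024:2048) -> k_proj,
# [2048:3072) -> v_proj, stored under the transformers key layout.
weights = sd.pop("cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight")
p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
for x in range(3):
    k_to = "cond_stage_model.transformer.text_model.encoder.layers.0.{}.weight".format(p[x])
    sd[k_to] = weights[1024 * x:1024 * (x + 1)]

print({k: tuple(v.shape) for k, v in sd.items()})  # three (1024, 1024) entries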
@@ -62,12 +98,6 @@ LORA_CLIP_MAP = {
"self_attn.out_proj"
:
"self_attn_out_proj"
,
"self_attn.out_proj"
:
"self_attn_out_proj"
,
}
}
LORA_CLIP2_MAP
=
{
"mlp.c_fc"
:
"mlp_fc1"
,
"mlp.c_proj"
:
"mlp_fc2"
,
"attn.out_proj"
:
"self_attn_out_proj"
,
}
LORA_UNET_MAP
=
{
LORA_UNET_MAP
=
{
"proj_in"
:
"proj_in"
,
"proj_in"
:
"proj_in"
,
"proj_out"
:
"proj_out"
,
"proj_out"
:
"proj_out"
,
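
Aside: LORA_CLIP2_MAP existed only because SD2's open_clip text encoder had its own key layout (model.transformer.resblocks.*). Now that SD2 checkpoints are converted to the same transformers layout as SD1, their text-encoder LoRA keys resolve through the ordinary LORA_CLIP_MAP; only the layer count differs (24 instead of 12). A small illustration of the resulting key naming (the single map entry shown is an assumed subset of LORA_CLIP_MAP):

# Illustrative subset of LORA_CLIP_MAP; the full map lives in comfy/sd.py.
lora_clip_map = {"mlp.fc1": "mlp_fc1"}

b, c = 23, "mlp.fc1"  # layer 23 exists for SD2 (24 layers), not for SD1 (12)
k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, lora_clip_map[c])
print(k, "->", lora_key)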
@@ -116,7 +146,7 @@ def model_lora_keys(model, key_map={}):
             k = "{}.{}.weight".format(tk, c)
             if k in sdk:
                 lora_key = "lora_unet_down_blocks_{}_attentions_{}_{}".format(counter // 2, counter % 2, LORA_UNET_MAP[c])
-                key_map[lora_key] = (k, 0)
+                key_map[lora_key] = k
                 up_counter += 1
         if up_counter >= 4:
             counter += 1
@@ -124,7 +154,7 @@ def model_lora_keys(model, key_map={}):
         k = "model.diffusion_model.middle_block.1.{}.weight".format(c)
         if k in sdk:
             lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP[c])
-            key_map[lora_key] = (k, 0)
+            key_map[lora_key] = k
     counter = 3
     for b in range(12):
         tk = "model.diffusion_model.output_blocks.{}.1".format(b)
@@ -133,29 +163,18 @@ def model_lora_keys(model, key_map={}):
             k = "{}.{}.weight".format(tk, c)
             if k in sdk:
                 lora_key = "lora_unet_up_blocks_{}_attentions_{}_{}".format(counter // 3, counter % 3, LORA_UNET_MAP[c])
-                key_map[lora_key] = (k, 0)
+                key_map[lora_key] = k
                 up_counter += 1
         if up_counter >= 4:
             counter += 1
 
     counter = 0
     text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
-    for b in range(12):
+    for b in range(24):
         for c in LORA_CLIP_MAP:
             k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
             if k in sdk:
                 lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
-                key_map[lora_key] = (k, 0)
+                key_map[lora_key] = k
-
-    for b in range(24):
-        for c in LORA_CLIP2_MAP:
-            k = "model.transformer.resblocks.{}.{}.weight".format(b, c)
-            if k in sdk:
-                lora_key = text_model_lora_key.format(b, LORA_CLIP2_MAP[c])
-                key_map[lora_key] = (k, 0)
-        k = "model.transformer.resblocks.{}.attn.in_proj_weight".format(b)
-        if k in sdk:
-            key_map[text_model_lora_key.format(b, "self_attn_q_proj")] = (k, 0)
-            key_map[text_model_lora_key.format(b, "self_attn_k_proj")] = (k, 1)
-            key_map[text_model_lora_key.format(b, "self_attn_v_proj")] = (k, 2)
     return key_map
@@ -174,7 +193,7 @@ class ModelPatcher:
         p = {}
         model_sd = self.model.state_dict()
         for k in patches:
-            if k[0] in model_sd:
+            if k in model_sd:
                 p[k] = patches[k]
 
         self.patches += [(strength, p)]
         return p.keys()
@@ -184,8 +203,7 @@ class ModelPatcher:
         for p in self.patches:
             for k in p[1]:
                 v = p[1][k]
-                key = k[0]
-                index = k[1]
+                key = k
                 if key not in model_sd:
                     print("could not patch. key doesn't exist in model:", k)
                     continue
@@ -199,10 +217,7 @@ class ModelPatcher:
                 mat2 = v[1]
                 if v[2] is not None:
                     alpha *= v[2] / mat2.shape[0]
-                calc = (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float()))
-                if len(weight.shape) > 2:
-                    calc = calc.reshape(weight.shape)
-                weight[index*mat1.shape[0]:(index + 1)*mat1.shape[0]] += calc.type(weight.dtype).to(weight.device)
+                weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
         return self.model
 
     def unpatch_model(self):
         model_sd = self.model.state_dict()
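
The model_lora_keys and ModelPatcher hunks above are the flip side of the same cleanup: with open_clip, the q/k/v LoRA weights for SD2 all pointed at one stacked in_proj tensor, so key_map stored (key, index) tuples and the patcher wrote each low-rank update into the 1024-row slice selected by the index. With separate transformers projections, every LoRA key maps to exactly one model tensor, so the tuples and the slice bookkeeping disappear and a patch is simply added onto the whole weight. A sketch of the simplified update (an illustrative helper, not the actual ModelPatcher method; alpha, mat1 and mat2 play the roles they have in the last hunk):

import torch

def apply_lora_patch(weight, alpha, mat1, mat2):
    # weight: one model tensor, e.g. a (1024, 1024) q_proj weight.
    # mat1 @ mat2 is the low-rank LoRA update; flatten() also covers conv weights.
    update = alpha * torch.mm(mat1.flatten(start_dim=1).float(),
                              mat2.flatten(start_dim=1).float())
    weight += update.reshape(weight.shape).type(weight.dtype).to(weight.device)
    return weight

w = torch.zeros(1024, 1024)
up, down = torch.randn(1024, 8), torch.randn(8, 1024)
apply_lora_patch(w, 0.5, up, down)  # mirrors the new "weight += (...)" line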

comfy/sd2_clip.py (+10 -74)

 import sd1_clip
-import open_clip
 import torch
+import os
 
-class SD2ClipModel(torch.nn.Module, sd1_clip.ClipTokenWeightEncoder):
-    """
-    Uses the OpenCLIP transformer encoder for text
-    """
-    LAYERS = [
-        #"pooled",
-        "last",
-        "penultimate",
-        "hidden"
-    ]
-    #version="laion2b_s32b_b79k"
-    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None):
-        super().__init__()
-        assert layer in self.LAYERS
-        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'))
-        del model.visual
-        self.model = model
-
-        self.device = device
-        self.max_length = max_length
+class SD2ClipModel(sd1_clip.SD1ClipModel):
+    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None):
+        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
+        super().__init__(device=device, freeze=freeze, textmodel_json_config=textmodel_json_config)
         self.empty_tokens = [[49406] + [49407] + [0] * 75]
-        if freeze:
-            self.freeze()
-        self.layer = layer
-        if self.layer == "last":
-            self.layer_idx = 0
-        elif self.layer == "penultimate":
-            self.layer_idx = 1
+        if layer == "last":
+            layer_idx = -1
+        elif layer == "penultimate":
+            layer_idx = -2
         elif self.layer == "hidden":
             assert layer_idx is not None
             assert abs(layer_idx) < 24
-            self.clip_layer(layer_idx)
         else:
             raise NotImplementedError()
+        self.clip_layer(layer_idx)
 
-    def freeze(self):
-        self.model = self.model.eval()
-        for param in self.parameters():
-            param.requires_grad = False
-
-    def clip_layer(self, layer_idx):
-        #layer_idx should have the same logic as the one for SD1
-        if abs(layer_idx) >= 24:
-            self.layer_idx = 0
-        else:
-            if layer_idx < 0:
-                self.layer_idx = -(layer_idx + 1)
-            else:
-                self.layer_idx = 24 - (layer_idx + 1)
-
-    def forward(self, tokens):
-        tokens = torch.LongTensor(tokens).to(self.device)
-        z = self.encode_with_transformer(tokens)
-        return z
-
-    def encode_with_transformer(self, tokens):
-        x = self.model.token_embedding(tokens)  # [batch_size, n_ctx, d_model]
-        x = x + self.model.positional_embedding
-        x = x.permute(1, 0, 2)  # NLD -> LND
-        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
-        x = x.permute(1, 0, 2)  # LND -> NLD
-        x = self.model.ln_final(x)
-        return x
-
-    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
-        for i, r in enumerate(self.model.transformer.resblocks):
-            if i == len(self.model.transformer.resblocks) - self.layer_idx:
-                break
-            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
-                x = checkpoint(r, x, attn_mask)
-            else:
-                x = r(x, attn_mask=attn_mask)
-        return x
-
-    def encode(self, tokens):
-        return self(tokens)
-
 class SD2Tokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, tokenizer_path=None):
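
With the open_clip tower gone, SD2ClipModel becomes a thin wrapper: it points sd1_clip.SD1ClipModel at a JSON config describing the SD2 text encoder and reuses the SD1 code path, with layer_idx = -2 selecting the penultimate layer that SD 2.x checkpoints were trained against. On the transformers side, "penultimate" boils down to roughly the following (a sketch with random weights, run from the repo root, assuming SD1ClipModel builds a transformers CLIPTextModel from that JSON as the SD1 path does; it is not the actual SD1ClipModel forward):

import torch
from transformers import CLIPTextConfig, CLIPTextModel

config = CLIPTextConfig.from_json_file("comfy/sd2_clip_config.json")
model = CLIPTextModel(config)  # randomly initialised 24-layer, 1024-wide text encoder

tokens = torch.tensor([[49406] + [49407] + [0] * 75])  # the empty prompt from the diff
out = model(input_ids=tokens, output_hidden_states=True)
penultimate = out.hidden_states[-2]                     # second-to-last layer output
print(penultimate.shape)                                # torch.Size([1, 77, 1024])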

comfy/sd2_clip_config.json (+23 -0, new file, mode 100644)

+{
+  "architectures": [
+    "CLIPTextModel"
+  ],
+  "attention_dropout": 0.0,
+  "bos_token_id": 0,
+  "dropout": 0.0,
+  "eos_token_id": 2,
+  "hidden_act": "gelu",
+  "hidden_size": 1024,
+  "initializer_factor": 1.0,
+  "initializer_range": 0.02,
+  "intermediate_size": 4096,
+  "layer_norm_eps": 1e-05,
+  "max_position_embeddings": 77,
+  "model_type": "clip_text_model",
+  "num_attention_heads": 16,
+  "num_hidden_layers": 24,
+  "pad_token_id": 1,
+  "projection_dim": 512,
+  "torch_dtype": "float32",
+  "vocab_size": 49408
+}
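
These values mirror the open_clip ViT-H/14 text tower that SD 2.x was trained with (24 layers, width 1024, 16 heads, GELU MLPs of width 4096), which is why the renamed checkpoint keys from sd.py load straight into a transformers CLIPTextModel. A quick check against the numbers the conversion code assumes:

from transformers import CLIPTextConfig

cfg = CLIPTextConfig.from_json_file("comfy/sd2_clip_config.json")
assert cfg.num_hidden_layers == 24     # "for resblock in range(24)" in sd.py
assert cfg.hidden_size == 1024         # in_proj_weight is sliced in 1024-row chunks
assert cfg.intermediate_size == 4096   # mlp.c_fc / mlp.fc1 width
assert cfg.num_attention_heads == 16 and cfg.max_position_embeddings == 77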