chenpangpang / ComfyUI · Commits

Commit 60127a83, authored Apr 05, 2023 by sALTaccount
diffusers loader
parent d5cce834
Showing 3 changed files with 382 additions and 1 deletion:

  comfy/diffusers_convert.py                    +364  -0
  models/diffusers/put_diffusers_models_here      +0  -0
  nodes.py                                       +18  -1
comfy/diffusers_convert.py   (new file, mode 100644)
import json
import os
import yaml

# because of local import nonsense
import sys
sys.path.append(os.path.dirname(os.path.realpath(__file__)))

import folder_paths
from comfy.ldm.util import instantiate_from_config
from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE

import os.path as osp
import re
import torch
from safetensors.torch import load_file, save_file

# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
# =================#
# UNet Conversion #
# =================#

unet_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
]

unet_conversion_map_resnet = [
    # (stable-diffusion, HF Diffusers)
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]
unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
    # loop over downblocks/upblocks

    for j in range(2):
        # loop over resnets/attentions for downblocks
        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
        sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

        if i < 3:
            # no attention layers in down_blocks.3
            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

    for j in range(3):
        # loop over resnets/attentions for upblocks
        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
        sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

        if i > 0:
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

    if i < 3:
        # no downsample in down_blocks.3
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
        sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

        # no upsample in up_blocks.3
        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

for j in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{j}."
    sd_mid_res_prefix = f"middle_block.{2 * j}."
    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
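# Illustrative sanity check (not part of this commit): a few of the generated
# (stable-diffusion, HF Diffusers) prefix pairs produced by the loops above.
assert ("input_blocks.1.0.", "down_blocks.0.resnets.0.") in unet_conversion_map_layer
assert ("output_blocks.2.1.", "up_blocks.0.upsamplers.0.") in unet_conversion_map_layer
assert ("middle_block.1.", "mid_block.attentions.0.") in unet_conversion_map_layer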
def convert_unet_state_dict(unet_state_dict):
    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
    return new_state_dict
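# Illustrative usage sketch (not part of this commit); the path is hypothetical.
# Keys are renamed from the HF Diffusers layout to the stable-diffusion layout, e.g.
# "down_blocks.0.resnets.0.norm1.weight" -> "input_blocks.1.0.in_layers.0.weight".
example_unet = load_file("unet/diffusion_pytorch_model.safetensors", device="cpu")
example_unet_sd = convert_unet_state_dict(example_unet)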
# ================#
# VAE Conversion #
# ================#

vae_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("nin_shortcut", "conv_shortcut"),
    ("norm_out", "conv_norm_out"),
    ("mid.attn_1.", "mid_block.attentions.0."),
]

for i in range(4):
    # down_blocks have two resnets
    for j in range(2):
        hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
        sd_down_prefix = f"encoder.down.{i}.block.{j}."
        vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

    if i < 3:
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
        sd_downsample_prefix = f"down.{i}.downsample."
        vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"up.{3 - i}.upsample."
        vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

    # up_blocks have three resnets
    # also, up blocks in hf are numbered in reverse from sd
    for j in range(3):
        hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
        sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
        vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{i}."
    sd_mid_res_prefix = f"mid.block_{i + 1}."
    vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))

vae_conversion_map_attn = [
    # (stable-diffusion, HF Diffusers)
    ("norm.", "group_norm."),
    ("q.", "query."),
    ("k.", "key."),
    ("v.", "value."),
    ("proj_out.", "proj_attn."),
]
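# Illustrative check (not part of this commit): decoder up blocks are numbered in
# reverse between the two formats, so HF "decoder.up_blocks.0." maps to SD "decoder.up.3.".
assert ("decoder.up.3.block.0.", "decoder.up_blocks.0.resnets.0.") in vae_conversion_map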
def reshape_weight_for_sd(w):
    # convert HF linear weights to SD conv2d weights
    return w.reshape(*w.shape, 1, 1)


def convert_vae_state_dict(vae_state_dict):
    mapping = {k: k for k in vae_state_dict.keys()}
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    for k, v in mapping.items():
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    weights_to_convert = ["q", "k", "v", "proj_out"]
    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                print(f"Reshaping {k} for SD format")
                new_state_dict[k] = reshape_weight_for_sd(v)
    return new_state_dict
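# Illustrative example (not part of this commit): the mid-block attention projections
# are 2-D linear weights in Diffusers but 1x1 conv weights in SD checkpoints, so a
# hypothetical (512, 512) projection weight becomes (512, 512, 1, 1) after reshaping.
example_proj = torch.ones(512, 512)
assert reshape_weight_for_sd(example_proj).shape == (512, 512, 1, 1)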
# =========================#
# Text Encoder Conversion #
# =========================#

textenc_conversion_lst = [
    # (stable-diffusion, HF Diffusers)
    ("resblocks.", "text_model.encoder.layers."),
    ("ln_1", "layer_norm1"),
    ("ln_2", "layer_norm2"),
    (".c_fc.", ".fc1."),
    (".c_proj.", ".fc2."),
    (".attn", ".self_attn"),
    ("ln_final.", "transformer.text_model.final_layer_norm."),
    ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
    ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
]
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))

# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
code2idx = {"q": 0, "k": 1, "v": 2}
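# Illustrative example (not part of this commit): relabel one HF text-encoder key
# with the pattern built above.
example_key = "text_model.encoder.layers.0.layer_norm1.weight"
assert textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], example_key) == "resblocks.0.ln_1.weight"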
def convert_text_enc_state_dict_v20(text_enc_dict):
    new_state_dict = {}
    capture_qkv_weight = {}
    capture_qkv_bias = {}
    for k, v in text_enc_dict.items():
        if (
            k.endswith(".self_attn.q_proj.weight")
            or k.endswith(".self_attn.k_proj.weight")
            or k.endswith(".self_attn.v_proj.weight")
        ):
            k_pre = k[: -len(".q_proj.weight")]
            k_code = k[-len("q_proj.weight")]
            if k_pre not in capture_qkv_weight:
                capture_qkv_weight[k_pre] = [None, None, None]
            capture_qkv_weight[k_pre][code2idx[k_code]] = v
            continue

        if (
            k.endswith(".self_attn.q_proj.bias")
            or k.endswith(".self_attn.k_proj.bias")
            or k.endswith(".self_attn.v_proj.bias")
        ):
            k_pre = k[: -len(".q_proj.bias")]
            k_code = k[-len("q_proj.bias")]
            if k_pre not in capture_qkv_bias:
                capture_qkv_bias[k_pre] = [None, None, None]
            capture_qkv_bias[k_pre][code2idx[k_code]] = v
            continue

        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
        new_state_dict[relabelled_key] = v

    for k_pre, tensors in capture_qkv_weight.items():
        if None in tensors:
            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
        new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)

    for k_pre, tensors in capture_qkv_bias.items():
        if None in tensors:
            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
        new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)

    return new_state_dict


def convert_text_enc_state_dict(text_enc_dict):
    return text_enc_dict
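# Illustrative sketch (not part of this commit): for an SD 2.x (OpenCLIP) text encoder,
# the separate q/k/v projection weights, each hypothetically of shape (1024, 1024), are
# concatenated into a single in_proj_weight of shape (3072, 1024).
example_qkv = {
    f"transformer.text_model.encoder.layers.0.self_attn.{p}_proj.weight": torch.zeros(1024, 1024)
    for p in ("q", "k", "v")
}
example_merged = convert_text_enc_state_dict_v20(example_qkv)
assert example_merged["transformer.resblocks.0.attn.in_proj_weight"].shape == (3072, 1024)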
def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
    diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
    diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json")))

    # magic
    v2 = diffusers_unet_conf["sample_size"] == 96
    v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction'

    if v2:
        if v_pred:
            config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml')
        else:
            config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml')
    else:
        config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml')

    with open(config_path, 'r') as stream:
        config = yaml.safe_load(stream)

    model_config_params = config['model']['params']
    clip_config = model_config_params['cond_stage_config']
    scale_factor = model_config_params['scale_factor']
    vae_config = model_config_params['first_stage_config']
    vae_config['scale_factor'] = scale_factor

    unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
    vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
    text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")

    # Load each model from safetensors if the file exists; otherwise fall back to the pytorch .bin file
    if osp.exists(unet_path):
        unet_state_dict = load_file(unet_path, device="cpu")
    else:
        unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
        unet_state_dict = torch.load(unet_path, map_location="cpu")

    if osp.exists(vae_path):
        vae_state_dict = load_file(vae_path, device="cpu")
    else:
        vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
        vae_state_dict = torch.load(vae_path, map_location="cpu")

    if osp.exists(text_enc_path):
        text_enc_dict = load_file(text_enc_path, device="cpu")
    else:
        text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
        text_enc_dict = torch.load(text_enc_path, map_location="cpu")
    # Convert the UNet model
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
    is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict

    if is_v20_model:
        # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
        text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
        text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
        text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
    else:
        text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
        text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

    # Put together new checkpoint
    sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}

    clip = None
    vae = None

    class WeightsLoader(torch.nn.Module):
        pass

    w = WeightsLoader()
    load_state_dict_to = []
    if output_vae:
        vae = VAE(scale_factor=scale_factor, config=vae_config)
        w.first_stage_model = vae.first_stage_model
        load_state_dict_to = [w]

    if output_clip:
        clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
        w.cond_stage_model = clip.cond_stage_model
        load_state_dict_to = [w]

    model = instantiate_from_config(config["model"])
    model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)

    if fp16:
        model = model.half()

    return ModelPatcher(model), clip, vae
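For context, a minimal usage sketch of the new entry point (illustrative only, not part of the commit; the folder name "my_model" is a hypothetical model placed under models/diffusers):

# Hypothetical call into load_diffusers with a local Diffusers-format model folder.
model, clip, vae = load_diffusers(
    os.path.join(folder_paths.models_dir, "diffusers", "my_model"),
    fp16=True,
    output_vae=True,
    output_clip=True,
    embedding_directory=folder_paths.get_folder_paths("embeddings"),
)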
models/diffusers/put_diffusers_models_here   (new empty placeholder file, mode 100644)
nodes.py
@@ -4,13 +4,14 @@ import os
 import sys
 import json
 import hashlib
-import copy
 import traceback

 from PIL import Image
 from PIL.PngImagePlugin import PngInfo

 import numpy as np
+from comfy.diffusers_convert import load_diffusers

 sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
@@ -219,6 +220,21 @@ class CheckpointLoaderSimple:
         out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
         return out

+class DiffusersLoader:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {"required": {"model_path": (os.listdir(os.path.join(folder_paths.models_dir, 'diffusers')),), }}
+    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
+    FUNCTION = "load_checkpoint"
+
+    CATEGORY = "loaders"
+
+    def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
+        model_path = os.path.join(folder_paths.models_dir, 'diffusers', model_path)
+        return load_diffusers(model_path, fp16=True, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
+
 class unCLIPCheckpointLoader:
     @classmethod
     def INPUT_TYPES(s):
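The new node simply joins the selected folder name onto models/diffusers and delegates to load_diffusers. A minimal sketch of invoking it outside a graph (hypothetical, not part of the commit; "my_model" stands in for a folder the user has placed under models/diffusers):

# Hypothetical direct invocation of the new node class.
loader = DiffusersLoader()
model, clip, vae = loader.load_checkpoint("my_model")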
@@ -1076,6 +1092,7 @@ NODE_CLASS_MAPPINGS = {
     "TomePatchModel": TomePatchModel,
     "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
     "CheckpointLoader": CheckpointLoader,
+    "DiffusersLoader": DiffusersLoader,
 }

 def load_custom_node(module_path):