ComfyUI commit af7a4991
Authored Jul 05, 2023 by comfyanonymous
Support loading unet files in diffusers format.
Parent: e57cba4c

Showing 9 changed files with 123 additions and 15 deletions (+123 -15).
comfy/diffusers_load.py          +2  -1
comfy/model_detection.py         +5  -3
comfy/sd.py                      +68 -1
comfy/supported_models.py        +4  -4
comfy/supported_models_base.py   +5  -5
comfy/utils.py                   +21 -0
folder_paths.py                  +1  -0
models/unet/put_unet_files_here  +0  -0 (new file)
nodes.py                         +17 -1
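In outline: the new comfy.sd.load_unet() reads a diffusers-format UNet state dict, fingerprints it by tensor shapes (context_dim, model_channels, in_channels, adm_in_channels), matches that fingerprint against a hard-coded table of known UNet configs (SDXL, SDXL refiner, SD2.1, SD1.5, plus the SD2.1 unCLIP variants), remaps the keys to ComfyUI's internal names via the new utils.unet_to_diffusers() mapping, and wraps the result in a ModelPatcher. A new UNETLoader node exposes this through the newly registered models/unet folder.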
comfy/diffusers_load.py

@@ -8,7 +8,8 @@ import os.path as osp
 import re
 import torch
 from safetensors.torch import load_file, save_file
-import diffusers_convert
+from . import diffusers_convert

 def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None):
     diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json")))
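Note: the bare import only resolved when comfy/ itself was on sys.path; the package-relative form resolves diffusers_convert correctly when the module is imported as comfy.diffusers_load.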
comfy/model_detection.py

@@ -108,11 +108,13 @@ def detect_unet_config(state_dict, key_prefix, use_fp16):
     unet_config["context_dim"] = context_dim
     return unet_config

-def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
-    unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
-
+def model_config_from_unet_config(unet_config):
     for model_config in supported_models.models:
         if model_config.matches(unet_config):
             return model_config(unet_config)

     return None
+
+def model_config_from_unet(state_dict, unet_key_prefix, use_fp16):
+    unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
+    return model_config_from_unet_config(unet_config)
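The refactor splits state-dict detection from config matching, so a caller that already holds a bare unet_config dict (as the new load_unet() below does) can resolve it directly. A minimal sketch of the two entry points; sd and cfg are hypothetical placeholders:

    from comfy import model_detection

    # From a checkpoint state dict: detect the config, then match it.
    mc = model_detection.model_config_from_unet(sd, "model.diffusion_model.", True)

    # From a config dict assembled some other way (what load_unet does);
    # returns None when nothing in supported_models.models matches.
    mc = model_detection.model_config_from_unet_config(cfg)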
comfy/sd.py

@@ -1049,7 +1049,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)

     offload_device = model_management.unet_offload_device()
-    model = model_config.get_model(sd)
+    model = model_config.get_model(sd, "model.diffusion_model.")
     model = model.to(offload_device)
     model.load_model_weights(sd, "model.diffusion_model.")

@@ -1073,6 +1073,73 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     return (ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae, clipvision)

+def load_unet(unet_path): #load unet in diffusers format
+    sd = utils.load_torch_file(unet_path)
+    parameters = calculate_parameters(sd, "")
+    fp16 = model_management.should_use_fp16(model_params=parameters)
+
+    match = {}
+    match["context_dim"] = sd["down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
+    match["model_channels"] = sd["conv_in.weight"].shape[0]
+    match["in_channels"] = sd["conv_in.weight"].shape[1]
+    match["adm_in_channels"] = None
+    if "class_embedding.linear_1.weight" in sd:
+        match["adm_in_channels"] = sd["class_embedding.linear_1.weight"].shape[1]
+
+    SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048}
+
+    SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 384, 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280}
+
+    SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
+
+    SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None, 'use_fp16': True, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768}
+
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl]
+    print("match", match)
+    for unet_config in supported_models:
+        matches = True
+        for k in match:
+            if match[k] != unet_config[k]:
+                matches = False
+                break
+        if matches:
+            diffusers_keys = utils.unet_to_diffusers(unet_config)
+            new_sd = {}
+            for k in diffusers_keys:
+                if k in sd:
+                    new_sd[diffusers_keys[k]] = sd.pop(k)
+                else:
+                    print(diffusers_keys[k], k)
+            offload_device = model_management.unet_offload_device()
+            model_config = model_detection.model_config_from_unet_config(unet_config)
+            model = model_config.get_model(new_sd, "")
+            model = model.to(offload_device)
+            model.load_model_weights(new_sd, "")
+            return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
+
 def save_checkpoint(output_path, model, clip, vae, metadata=None):
     try:
         model.patch_model()
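A hedged usage sketch of the new entry point (the file name is hypothetical). load_unet() returns a ModelPatcher wrapping only the UNet, so CLIP and VAE still have to come from elsewhere, and it implicitly returns None when the shape fingerprint matches none of the hard-coded configs:

    import comfy.sd

    # A diffusers-format UNet (e.g. a diffusion_pytorch_model.safetensors
    # taken from a diffusers repo) placed in models/unet:
    model = comfy.sd.load_unet("models/unet/sdxl_base_unet.safetensors")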
comfy/supported_models.py

@@ -53,9 +53,9 @@ class SD20(supported_models_base.BASE):
     latent_format = latent_formats.SD15

-    def v_prediction(self, state_dict):
+    def v_prediction(self, state_dict, prefix=""):
         if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
-            k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
+            k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
             out = state_dict[k]
             if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
                 return True

@@ -109,7 +109,7 @@ class SDXLRefiner(supported_models_base.BASE):
     latent_format = latent_formats.SDXL

-    def get_model(self, state_dict):
+    def get_model(self, state_dict, prefix=""):
         return model_base.SDXLRefiner(self)

     def process_clip_state_dict(self, state_dict):

@@ -144,7 +144,7 @@ class SDXL(supported_models_base.BASE):
     latent_format = latent_formats.SDXL

-    def get_model(self, state_dict):
+    def get_model(self, state_dict, prefix=""):
         return model_base.SDXL(self)

     def process_clip_state_dict(self, state_dict):
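The new prefix parameter lets v_prediction() probe the same key under either state-dict layout. Both call shapes now in use (arguments illustrative):

    model_config.v_prediction(sd, "model.diffusion_model.")  # checkpoint state dict
    model_config.v_prediction(new_sd, "")                    # remapped diffusers unet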
comfy/supported_models_base.py

@@ -41,7 +41,7 @@ class BASE:
             return False
         return True

-    def v_prediction(self, state_dict):
+    def v_prediction(self, state_dict, prefix=""):
         return False

     def inpaint_model(self):

@@ -53,13 +53,13 @@ class BASE:
         for x in self.unet_extra_config:
             self.unet_config[x] = self.unet_extra_config[x]

-    def get_model(self, state_dict):
+    def get_model(self, state_dict, prefix=""):
         if self.inpaint_model():
-            return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict))
+            return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict, prefix))
         elif self.noise_aug_config is not None:
-            return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
+            return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict, prefix))
         else:
-            return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict))
+            return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict, prefix))

     def process_clip_state_dict(self, state_dict):
         return state_dict
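BASE.get_model() threads the same prefix through to v_prediction(). Because the default is prefix="", call sites that previously relied on the implicit model.diffusion_model. prefix must now pass it explicitly, which is what the load_checkpoint_guess_config() change in comfy/sd.py above does.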
comfy/utils.py

@@ -117,6 +117,23 @@ UNET_MAP_RESNET = {
     "out_layers.0.bias": "norm2.bias",
 }

+UNET_MAP_BASIC = {
+    "label_emb.0.0.weight": "class_embedding.linear_1.weight",
+    "label_emb.0.0.bias": "class_embedding.linear_1.bias",
+    "label_emb.0.2.weight": "class_embedding.linear_2.weight",
+    "label_emb.0.2.bias": "class_embedding.linear_2.bias",
+    "input_blocks.0.0.weight": "conv_in.weight",
+    "input_blocks.0.0.bias": "conv_in.bias",
+    "out.0.weight": "conv_norm_out.weight",
+    "out.0.bias": "conv_norm_out.bias",
+    "out.2.weight": "conv_out.weight",
+    "out.2.bias": "conv_out.bias",
+    "time_embed.0.weight": "time_embedding.linear_1.weight",
+    "time_embed.0.bias": "time_embedding.linear_1.bias",
+    "time_embed.2.weight": "time_embedding.linear_2.weight",
+    "time_embed.2.bias": "time_embedding.linear_2.bias"
+}
+
 def unet_to_diffusers(unet_config):
     num_res_blocks = unet_config["num_res_blocks"]
     attention_resolutions = unet_config["attention_resolutions"]

@@ -185,6 +202,10 @@ def unet_to_diffusers(unet_config):
                 for k in ["weight", "bias"]:
                     diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
         n += 1

+    for k in UNET_MAP_BASIC:
+        diffusers_unet_map[UNET_MAP_BASIC[k]] = k
+
+    return diffusers_unet_map

 def convert_sd_to(state_dict, dtype):
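unet_to_diffusers() now also covers the non-block tensors and returns its mapping, keyed diffusers name -> ldm name (UNET_MAP_BASIC is stored ldm -> diffusers and inverted on insertion). An illustrative check, assuming a valid unet_config as built in load_unet():

    import comfy.utils

    mapping = comfy.utils.unet_to_diffusers(unet_config)
    assert mapping["conv_in.weight"] == "input_blocks.0.0.weight"
    assert mapping["time_embedding.linear_1.weight"] == "time_embed.0.weight"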
folder_paths.py

@@ -14,6 +14,7 @@ folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".y
 folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
 folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
 folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
+folder_names_and_paths["unet"] = ([os.path.join(models_dir, "unet")], supported_pt_extensions)
 folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
 folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
 folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
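With the unet entry registered, files placed in models/unet are discoverable through the usual helpers, e.g.:

    import folder_paths

    names = folder_paths.get_filename_list("unet")       # e.g. ["sdxl_base_unet.safetensors"]
    path = folder_paths.get_full_path("unet", names[0])  # feed to comfy.sd.load_unet()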
models/unet/put_unet_files_here (new file, mode 100644, empty)
nodes.py

@@ -397,7 +397,7 @@ class DiffusersLoader:
     RETURN_TYPES = ("MODEL", "CLIP", "VAE")
     FUNCTION = "load_checkpoint"

-    CATEGORY = "advanced/loaders"
+    CATEGORY = "advanced/loaders/deprecated"

     def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
         for search_path in folder_paths.get_folder_paths("diffusers"):

@@ -552,6 +552,21 @@ class ControlNetApply:
             c.append(n)
         return (c, )

+class UNETLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ),
+                             }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "load_unet"
+
+    CATEGORY = "advanced/loaders"
+
+    def load_unet(self, unet_name):
+        unet_path = folder_paths.get_full_path("unet", unet_name)
+        model = comfy.sd.load_unet(unet_path)
+        return (model,)
+
 class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):

@@ -1371,6 +1386,7 @@ NODE_CLASS_MAPPINGS = {
     "LatentCrop": LatentCrop,
     "LoraLoader": LoraLoader,
     "CLIPLoader": CLIPLoader,
+    "UNETLoader": UNETLoader,
     "DualCLIPLoader": DualCLIPLoader,
     "CLIPVisionEncode": CLIPVisionEncode,
     "StyleModelApply": StyleModelApply,
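In the UI the node shows up under advanced/loaders, while DiffusersLoader moves to advanced/loaders/deprecated. A hypothetical headless invocation (the file name is made up):

    loader = UNETLoader()
    (model,) = loader.load_unet("sdxl_base_unet.safetensors")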