Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
58b2364f
Commit
58b2364f
authored
Jul 21, 2023
by
comfyanonymous
Browse files
Properly support SDXL diffusers unet with UNETLoader node.
parent
01150186
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
16 deletions
+23
-16
comfy/sd.py
comfy/sd.py
+4
-1
comfy/utils.py
comfy/utils.py
+19
-15
No files found.
comfy/sd.py
View file @
58b2364f
...
...
@@ -1139,12 +1139,14 @@ def load_unet(unet_path): #load unet in diffusers format
fp16
=
model_management
.
should_use_fp16
(
model_params
=
parameters
)
match
=
{}
match
[
"context_dim"
]
=
sd
[
"down_blocks.
0
.attentions.1.transformer_blocks.0.attn2.to_k.weight"
].
shape
[
1
]
match
[
"context_dim"
]
=
sd
[
"down_blocks.
1
.attentions.1.transformer_blocks.0.attn2.to_k.weight"
].
shape
[
1
]
match
[
"model_channels"
]
=
sd
[
"conv_in.weight"
].
shape
[
0
]
match
[
"in_channels"
]
=
sd
[
"conv_in.weight"
].
shape
[
1
]
match
[
"adm_in_channels"
]
=
None
if
"class_embedding.linear_1.weight"
in
sd
:
match
[
"adm_in_channels"
]
=
sd
[
"class_embedding.linear_1.weight"
].
shape
[
1
]
elif
"add_embedding.linear_1.weight"
in
sd
:
match
[
"adm_in_channels"
]
=
sd
[
"add_embedding.linear_1.weight"
].
shape
[
1
]
SDXL
=
{
'use_checkpoint'
:
False
,
'image_size'
:
32
,
'out_channels'
:
4
,
'use_spatial_transformer'
:
True
,
'legacy'
:
False
,
'num_classes'
:
'sequential'
,
'adm_in_channels'
:
2816
,
'use_fp16'
:
fp16
,
'in_channels'
:
4
,
'model_channels'
:
320
,
...
...
@@ -1198,6 +1200,7 @@ def load_unet(unet_path): #load unet in diffusers format
model
=
model
.
to
(
offload_device
)
model
.
load_model_weights
(
new_sd
,
""
)
return
ModelPatcher
(
model
,
load_device
=
model_management
.
get_torch_device
(),
offload_device
=
offload_device
)
print
(
"ERROR UNSUPPORTED UNET"
,
unet_path
)
def
save_checkpoint
(
output_path
,
model
,
clip
,
vae
,
metadata
=
None
):
try
:
...
...
comfy/utils.py
View file @
58b2364f
...
...
@@ -120,20 +120,24 @@ UNET_MAP_RESNET = {
}
UNET_MAP_BASIC
=
{
"label_emb.0.0.weight"
:
"class_embedding.linear_1.weight"
,
"label_emb.0.0.bias"
:
"class_embedding.linear_1.bias"
,
"label_emb.0.2.weight"
:
"class_embedding.linear_2.weight"
,
"label_emb.0.2.bias"
:
"class_embedding.linear_2.bias"
,
"input_blocks.0.0.weight"
:
"conv_in.weight"
,
"input_blocks.0.0.bias"
:
"conv_in.bias"
,
"out.0.weight"
:
"conv_norm_out.weight"
,
"out.0.bias"
:
"conv_norm_out.bias"
,
"out.2.weight"
:
"conv_out.weight"
,
"out.2.bias"
:
"conv_out.bias"
,
"time_embed.0.weight"
:
"time_embedding.linear_1.weight"
,
"time_embed.0.bias"
:
"time_embedding.linear_1.bias"
,
"time_embed.2.weight"
:
"time_embedding.linear_2.weight"
,
"time_embed.2.bias"
:
"time_embedding.linear_2.bias"
(
"label_emb.0.0.weight"
,
"class_embedding.linear_1.weight"
),
(
"label_emb.0.0.bias"
,
"class_embedding.linear_1.bias"
),
(
"label_emb.0.2.weight"
,
"class_embedding.linear_2.weight"
),
(
"label_emb.0.2.bias"
,
"class_embedding.linear_2.bias"
),
(
"label_emb.0.0.weight"
,
"add_embedding.linear_1.weight"
),
(
"label_emb.0.0.bias"
,
"add_embedding.linear_1.bias"
),
(
"label_emb.0.2.weight"
,
"add_embedding.linear_2.weight"
),
(
"label_emb.0.2.bias"
,
"add_embedding.linear_2.bias"
),
(
"input_blocks.0.0.weight"
,
"conv_in.weight"
),
(
"input_blocks.0.0.bias"
,
"conv_in.bias"
),
(
"out.0.weight"
,
"conv_norm_out.weight"
),
(
"out.0.bias"
,
"conv_norm_out.bias"
),
(
"out.2.weight"
,
"conv_out.weight"
),
(
"out.2.bias"
,
"conv_out.bias"
),
(
"time_embed.0.weight"
,
"time_embedding.linear_1.weight"
),
(
"time_embed.0.bias"
,
"time_embedding.linear_1.bias"
),
(
"time_embed.2.weight"
,
"time_embedding.linear_2.weight"
),
(
"time_embed.2.bias"
,
"time_embedding.linear_2.bias"
)
}
def
unet_to_diffusers
(
unet_config
):
...
...
@@ -208,7 +212,7 @@ def unet_to_diffusers(unet_config):
n
+=
1
for
k
in
UNET_MAP_BASIC
:
diffusers_unet_map
[
UNET_MAP_BASIC
[
k
]]
=
k
diffusers_unet_map
[
k
[
1
]]
=
k
[
0
]
return
diffusers_unet_map
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment