Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
0d6a5793
Commit
0d6a5793
authored
Jun 19, 2024
by
comfyanonymous
Browse files
Support loading diffusers SD3 model format with UNETLoader node.
parent
b08a9dd0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
84 additions
and
5 deletions
+84
-5
comfy/model_detection.py
comfy/model_detection.py
+37
-0
comfy/sd.py
comfy/sd.py
+8
-1
comfy/utils.py
comfy/utils.py
+16
-4
comfy_extras/nodes_model_merging_model_specific.py
comfy_extras/nodes_model_merging_model_specific.py
+23
-0
No files found.
comfy/model_detection.py
View file @
0d6a5793
import
comfy.supported_models
import
comfy.supported_models
import
comfy.supported_models_base
import
comfy.supported_models_base
import
comfy.utils
import
math
import
math
import
logging
import
logging
import
torch
def
count_blocks
(
state_dict_keys
,
prefix_string
):
def
count_blocks
(
state_dict_keys
,
prefix_string
):
count
=
0
count
=
0
...
@@ -431,3 +433,38 @@ def model_config_from_diffusers_unet(state_dict):
...
@@ -431,3 +433,38 @@ def model_config_from_diffusers_unet(state_dict):
if
unet_config
is
not
None
:
if
unet_config
is
not
None
:
return
model_config_from_unet_config
(
unet_config
)
return
model_config_from_unet_config
(
unet_config
)
return
None
return
None
def convert_diffusers_mmdit(state_dict, output_prefix=""):
    """Convert a diffusers-format MMDiT state dict to the mapped key layout.

    The block count is obtained from ``count_blocks`` with the
    ``'transformer_blocks.{}.'`` pattern and the key mapping from
    ``comfy.utils.mmdit_to_diffusers``. Each mapping value is either a plain
    target key (str) or a tuple ``(target_key, offset[, transform])``, where
    ``offset`` is ``None`` or the three arguments handed to ``Tensor.narrow``
    and ``transform`` is an optional callable applied to the source weight.

    Args:
        state_dict: Mapping of source key -> tensor. Every entry that gets
            converted is popped from it, so the caller's dict is modified.
        output_prefix: Forwarded unchanged to ``mmdit_to_diffusers``.

    Returns:
        A new dict keyed by the mapped names, or ``None`` when ``count_blocks``
        reports no ``transformer_blocks.N.`` entries.
    """
    depth = count_blocks(state_dict, 'transformer_blocks.{}.')
    if depth <= 0:
        # Callers compare the result against None to detect "not convertible",
        # so make that outcome explicit instead of relying on fall-through.
        return None

    out_sd = {}
    sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth}, output_prefix=output_prefix)
    for k, t in sd_map.items():
        weight = state_dict.get(k, None)
        if weight is None:
            # Source key absent: nothing to convert, leave state_dict alone.
            continue

        if isinstance(t, str):
            # Plain rename.
            out_sd[t] = weight
        else:
            target_key = t[0]
            offset = t[1]
            fun = t[2] if len(t) > 2 else None

            if offset is not None:
                old_weight = out_sd.get(target_key, None)
                if old_weight is None:
                    # First slice for this target: allocate an uninitialised
                    # buffer three times the source size along dim 0.
                    # NOTE(review): the factor 3 presumably corresponds to a
                    # fused q/k/v tensor -- confirm against the key map.
                    old_weight = torch.empty_like(weight)
                    old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))

                # View into the destination buffer; the write below fills it.
                w = old_weight.narrow(offset[0], offset[1], offset[2])
            else:
                # No offset: the source tensor itself is the destination and
                # is overwritten in place with the transformed values.
                old_weight = weight
                w = weight

            w[:] = weight if fun is None else fun(weight)
            out_sd[target_key] = old_weight

        state_dict.pop(k)

    return out_sd
comfy/sd.py
View file @
0d6a5793
...
@@ -568,7 +568,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format
...
@@ -568,7 +568,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format
unet_dtype
=
model_management
.
unet_dtype
(
model_params
=
parameters
)
unet_dtype
=
model_management
.
unet_dtype
(
model_params
=
parameters
)
load_device
=
model_management
.
get_torch_device
()
load_device
=
model_management
.
get_torch_device
()
if
"input_blocks.0.0.weight"
in
sd
or
'clf.1.weight'
in
sd
:
#ldm or stable cascade
if
'transformer_blocks.0.attn.add_q_proj.weight'
in
sd
:
#MMDIT SD3
new_sd
=
model_detection
.
convert_diffusers_mmdit
(
sd
,
""
)
if
new_sd
is
None
:
return
None
model_config
=
model_detection
.
model_config_from_unet
(
new_sd
,
""
)
if
model_config
is
None
:
return
None
elif
"input_blocks.0.0.weight"
in
sd
or
'clf.1.weight'
in
sd
:
#ldm or stable cascade
model_config
=
model_detection
.
model_config_from_unet
(
sd
,
""
)
model_config
=
model_detection
.
model_config_from_unet
(
sd
,
""
)
if
model_config
is
None
:
if
model_config
is
None
:
return
None
return
None
...
...
comfy/utils.py
View file @
0d6a5793
...
@@ -249,6 +249,11 @@ def unet_to_diffusers(unet_config):
...
@@ -249,6 +249,11 @@ def unet_to_diffusers(unet_config):
return
diffusers_unet_map
return
diffusers_unet_map
def swap_scale_shift(weight):
    """Return ``weight`` with its two dim-0 chunks in swapped order.

    ``weight`` is split into two chunks along dimension 0 and a new tensor is
    built with the second chunk first. The input tensor is not modified.
    """
    leading, trailing = weight.chunk(2, dim=0)
    return torch.cat([trailing, leading], dim=0)
MMDIT_MAP_BASIC
=
{
MMDIT_MAP_BASIC
=
{
(
"context_embedder.bias"
,
"context_embedder.bias"
),
(
"context_embedder.bias"
,
"context_embedder.bias"
),
(
"context_embedder.weight"
,
"context_embedder.weight"
),
(
"context_embedder.weight"
,
"context_embedder.weight"
),
...
@@ -263,8 +268,8 @@ MMDIT_MAP_BASIC = {
...
@@ -263,8 +268,8 @@ MMDIT_MAP_BASIC = {
(
"y_embedder.mlp.2.bias"
,
"time_text_embed.text_embedder.linear_2.bias"
),
(
"y_embedder.mlp.2.bias"
,
"time_text_embed.text_embedder.linear_2.bias"
),
(
"y_embedder.mlp.2.weight"
,
"time_text_embed.text_embedder.linear_2.weight"
),
(
"y_embedder.mlp.2.weight"
,
"time_text_embed.text_embedder.linear_2.weight"
),
(
"pos_embed"
,
"pos_embed.pos_embed"
),
(
"pos_embed"
,
"pos_embed.pos_embed"
),
(
"final_layer.adaLN_modulation.1.bias"
,
"norm_out.linear.bias"
),
(
"final_layer.adaLN_modulation.1.bias"
,
"norm_out.linear.bias"
,
swap_scale_shift
),
(
"final_layer.adaLN_modulation.1.weight"
,
"norm_out.linear.weight"
),
(
"final_layer.adaLN_modulation.1.weight"
,
"norm_out.linear.weight"
,
swap_scale_shift
),
(
"final_layer.linear.bias"
,
"proj_out.bias"
),
(
"final_layer.linear.bias"
,
"proj_out.bias"
),
(
"final_layer.linear.weight"
,
"proj_out.weight"
),
(
"final_layer.linear.weight"
,
"proj_out.weight"
),
}
}
...
@@ -313,8 +318,15 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
...
@@ -313,8 +318,15 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
for
k
in
MMDIT_MAP_BLOCK
:
for
k
in
MMDIT_MAP_BLOCK
:
key_map
[
"{}.{}"
.
format
(
block_from
,
k
[
1
])]
=
"{}.{}"
.
format
(
block_to
,
k
[
0
])
key_map
[
"{}.{}"
.
format
(
block_from
,
k
[
1
])]
=
"{}.{}"
.
format
(
block_to
,
k
[
0
])
for
k
in
MMDIT_MAP_BASIC
:
map_basic
=
MMDIT_MAP_BASIC
.
copy
()
key_map
[
k
[
1
]]
=
"{}{}"
.
format
(
output_prefix
,
k
[
0
])
map_basic
.
add
((
"joint_blocks.{}.context_block.adaLN_modulation.1.bias"
.
format
(
depth
-
1
),
"transformer_blocks.{}.norm1_context.linear.bias"
.
format
(
depth
-
1
),
swap_scale_shift
))
map_basic
.
add
((
"joint_blocks.{}.context_block.adaLN_modulation.1.weight"
.
format
(
depth
-
1
),
"transformer_blocks.{}.norm1_context.linear.weight"
.
format
(
depth
-
1
),
swap_scale_shift
))
for
k
in
map_basic
:
if
len
(
k
)
>
2
:
key_map
[
k
[
1
]]
=
(
"{}{}"
.
format
(
output_prefix
,
k
[
0
]),
None
,
k
[
2
])
else
:
key_map
[
k
[
1
]]
=
"{}{}"
.
format
(
output_prefix
,
k
[
0
])
return
key_map
return
key_map
...
...
comfy_extras/nodes_model_merging_model_specific.py
View file @
0d6a5793
...
@@ -52,9 +52,32 @@ class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks):
...
@@ -52,9 +52,32 @@ class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks):
return
{
"required"
:
arg_dict
}
return
{
"required"
:
arg_dict
}
class ModelMergeSD3(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    # Node variant of ModelMergeBlocks that declares two MODEL inputs plus one
    # FLOAT input per key prefix listed in INPUT_TYPES.
    # NOTE(review): how the base class consumes these prefixes is not visible
    # here -- presumably each one selects the weights sharing that prefix.
    CATEGORY = "advanced/model_merging/model_specific"

    @classmethod
    def INPUT_TYPES(s):
        """Return the required-input spec: two models and one float per prefix."""
        # A single spec tuple is shared by every FLOAT input.
        ratio_spec = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        # Order matters: it is the insertion order of the returned dict.
        prefixes = [
            "pos_embed.",
            "x_embedder.",
            "context_embedder.",
            "y_embedder.",
            "t_embedder.",
        ]
        prefixes.extend("joint_blocks.{}.".format(block) for block in range(38))
        prefixes.append("final_layer.")

        required = {"model1": ("MODEL",), "model2": ("MODEL",)}
        required.update(dict.fromkeys(prefixes, ratio_spec))
        return {"required": required}
NODE_CLASS_MAPPINGS
=
{
NODE_CLASS_MAPPINGS
=
{
"ModelMergeSD1"
:
ModelMergeSD1
,
"ModelMergeSD1"
:
ModelMergeSD1
,
"ModelMergeSD2"
:
ModelMergeSD1
,
#SD1 and SD2 have the same blocks
"ModelMergeSD2"
:
ModelMergeSD1
,
#SD1 and SD2 have the same blocks
"ModelMergeSDXL"
:
ModelMergeSDXL
,
"ModelMergeSDXL"
:
ModelMergeSDXL
,
"ModelMergeSD3"
:
ModelMergeSD3
,
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment