chenpangpang / ComfyUI · Commits

Commit f8f7568d · authored Jun 25, 2024 by comfyanonymous

Basic SD3 controlnet implementation.

Still missing the node to properly use it.
parent 66aaa140

Showing 5 changed files with 165 additions and 15 deletions (+165 −15):
- comfy/cldm/mmdit.py (+91 −0)
- comfy/controlnet.py (+42 −3)
- comfy/ldm/modules/diffusionmodules/mmdit.py (+23 −7)
- comfy/model_detection.py (+7 −4)
- comfy/utils.py (+2 −1)
comfy/cldm/mmdit.py (new file, 0 → 100644)
```python
import torch
from typing import Dict, Optional
from einops import repeat

import comfy.ldm.modules.diffusionmodules.mmdit
import comfy.latent_formats
# NOTE: `default` (used below when register tokens are enabled) is assumed to be
# the small fallback helper defined in the mmdit module; without these two
# imports, `repeat` and `default` would be unresolved names in this file.
from comfy.ldm.modules.diffusionmodules.mmdit import default


class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
    def __init__(
        self,
        num_blocks = None,
        dtype = None,
        device = None,
        operations = None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, device=device, operations=operations, final_layer=False, num_blocks=num_blocks, **kwargs)
        # controlnet_blocks: one linear projection per joint block
        self.controlnet_blocks = torch.nn.ModuleList([])
        for _ in range(len(self.joint_blocks)):
            self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))

        self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
            None,
            self.patch_size,
            self.in_channels,
            self.hidden_size,
            bias=True,
            strict_img_size=False,
            dtype=dtype,
            device=device,
            operations=operations)

        self.latent_format = comfy.latent_formats.SD3()

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        hint = None,
    ) -> Dict[str, torch.Tensor]:

        #weird sd3 controlnet specific stuff
        hint = hint * self.latent_format.scale_factor # self.latent_format.process_in(hint)
        y = torch.zeros_like(y)

        if self.context_processor is not None:
            context = self.context_processor(context)

        hw = x.shape[-2:]
        x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
        x += self.pos_embed_input(hint)

        c = self.t_embedder(timesteps, dtype=x.dtype)
        if y is not None and self.y_embedder is not None:
            y = self.y_embedder(y)
            c = c + y

        if context is not None:
            context = self.context_embedder(context)

        if self.register_length > 0:
            context = torch.cat(
                (
                    repeat(self.register, "1 ... -> b ...", b=x.shape[0]),
                    default(context, torch.Tensor([]).type_as(x)),
                ),
                1,
            )

        output = []

        blocks = len(self.joint_blocks)
        for i in range(blocks):
            context, x = self.joint_blocks[i](
                context,
                x,
                c=c,
                use_checkpoint=self.use_checkpoint,
            )

            out = self.controlnet_blocks[i](x)
            # repeat each block's residual so the output list covers the full
            # depth of the base model (minus one slot for its final block)
            count = self.depth // blocks
            if i == blocks - 1:
                count -= 1
            for j in range(count):
                output.append(out)

        return {"output": output}
```
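The dict returned above is consumed by the patched MMDiT in comfy/ldm/modules/diffusionmodules/mmdit.py (diff further down). A minimal sketch of the residual-repetition logic in isolation, using hypothetical depths and dummy tensors rather than anything from the commit:

```python
# Hypothetical sketch: a depth-24 base model driven by a 12-block controlnet.
# Each controlnet residual is repeated depth // num_blocks times; the last
# block contributes one fewer copy, so the list covers depth - 1 slots.
import torch

depth = 24        # base model depth (hypothetical)
num_blocks = 12   # controlnet block count (hypothetical)
hidden = 64 * depth

per_block = [torch.zeros(1, 16, hidden) for _ in range(num_blocks)]  # dummy residuals

output = []
for i, out in enumerate(per_block):
    count = depth // num_blocks
    if i == num_blocks - 1:
        count -= 1
    for _ in range(count):
        output.append(out)

assert len(output) == depth - 1  # 23: one residual per base block except the last
```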
comfy/controlnet.py
```diff
@@ -11,6 +11,7 @@ import comfy.ops
 import comfy.cldm.cldm
 import comfy.t2i_adapter.adapter
 import comfy.ldm.cascade.controlnet
+import comfy.cldm.mmdit


 def broadcast_image_to(tensor, target_batch_size, batched_number):
@@ -94,13 +95,17 @@ class ControlBase:
         for key in control:
             control_output = control[key]
+            applied_to = set()
             for i in range(len(control_output)):
                 x = control_output[i]
                 if x is not None:
                     if self.global_average_pooling:
                         x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])

-                    x *= self.strength
+                    if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
+                        applied_to.add(x)
+                        x *= self.strength
+
                     if x.dtype != output_dtype:
                         x = x.to(output_dtype)
@@ -120,17 +125,18 @@ class ControlBase:
                     if o[i].shape[0] < prev_val.shape[0]:
                         o[i] = prev_val + o[i]
                     else:
-                        o[i] += prev_val
+                        o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
         return out

 class ControlNet(ControlBase):
-    def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
+    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, device=None, load_device=None, manual_cast_dtype=None):
         super().__init__(device)
         self.control_model = control_model
         self.load_device = load_device
         if control_model is not None:
             self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())

+        self.compression_ratio = compression_ratio
         self.global_average_pooling = global_average_pooling
         self.model_sampling_current = None
         self.manual_cast_dtype = manual_cast_dtype
@@ -308,6 +314,37 @@ class ControlLora(ControlNet):
     def inference_memory_requirements(self, dtype):
         return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)

+def load_controlnet_mmdit(sd):
+    new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
+    model_config = comfy.model_detection.model_config_from_unet(new_sd, "", True)
+    num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
+    for k in sd:
+        new_sd[k] = sd[k]
+
+    supported_inference_dtypes = model_config.supported_inference_dtypes
+
+    controlnet_config = model_config.unet_config
+    unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
+    load_device = comfy.model_management.get_torch_device()
+    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+    if manual_cast_dtype is not None:
+        operations = comfy.ops.manual_cast
+    else:
+        operations = comfy.ops.disable_weight_init
+
+    control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
+    missing, unexpected = control_model.load_state_dict(new_sd, strict=False)
+
+    if len(missing) > 0:
+        logging.warning("missing controlnet keys: {}".format(missing))
+
+    if len(unexpected) > 0:
+        logging.debug("unexpected controlnet keys: {}".format(unexpected))
+
+    control = ControlNet(control_model, compression_ratio=1, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
+    return control
+

 def load_controlnet(ckpt_path, model=None):
     controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
     if "lora_controlnet" in controlnet_data:
@@ -360,6 +397,8 @@ def load_controlnet(ckpt_path, model=None):
         if len(leftover_keys) > 0:
             logging.warning("leftover keys: {}".format(leftover_keys))
         controlnet_data = new_sd
+    elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
+        return load_controlnet_mmdit(controlnet_data)

     pth_key = 'control_model.zero_convs.0.0.weight'
     pth = False
```
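A hedged usage sketch of the new dispatch path; the checkpoint filename is a placeholder, not something shipped with this commit:

```python
# Hypothetical usage: a diffusers-format SD3 controlnet (detected via its
# "controlnet_blocks.0.weight" key) now routes through load_controlnet_mmdit.
import comfy.controlnet

control = comfy.controlnet.load_controlnet("models/controlnet/sd3_controlnet.safetensors")  # placeholder path
# The returned ControlNet wraps a comfy.cldm.mmdit.ControlNet and is built with
# compression_ratio=1; the scale_factor applied to the hint in the model's
# forward suggests hints are expected in latent space rather than pixel space.
```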
comfy/ldm/modules/diffusionmodules/mmdit.py
```diff
@@ -745,6 +745,8 @@ class MMDiT(nn.Module):
         qkv_bias: bool = True,
         context_processor_layers = None,
         context_size = 4096,
+        num_blocks = None,
+        final_layer = True,
         dtype = None, #TODO
         device = None,
         operations = None,
@@ -766,7 +768,10 @@ class MMDiT(nn.Module):
         # apply magic --> this defines a head_size of 64
         self.hidden_size = 64 * depth
         num_heads = depth
+        if num_blocks is None:
+            num_blocks = depth

+        self.depth = depth
         self.num_heads = num_heads

         self.x_embedder = PatchEmbed(
@@ -821,7 +826,7 @@ class MMDiT(nn.Module):
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias,
                     attn_mode=attn_mode,
-                    pre_only=i == depth - 1,
+                    pre_only=(i == num_blocks - 1) and final_layer,
                     rmsnorm=rmsnorm,
                     scale_mod_only=scale_mod_only,
                     swiglu=swiglu,
@@ -830,11 +835,12 @@ class MMDiT(nn.Module):
                     device=device,
                     operations=operations
                 )
-                for i in range(depth)
+                for i in range(num_blocks)
             ]
         )

-        self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
+        if final_layer:
+            self.final_layer = FinalLayer(self.hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)

         if compile_core:
             assert False
@@ -893,6 +899,7 @@ class MMDiT(nn.Module):
         x: torch.Tensor,
         c_mod: torch.Tensor,
         context: Optional[torch.Tensor] = None,
+        control = None,
     ) -> torch.Tensor:
         if self.register_length > 0:
             context = torch.cat(
@@ -905,13 +912,20 @@ class MMDiT(nn.Module):
         # context is B, L', D
         # x is B, L, D
-        for block in self.joint_blocks:
-            context, x = block(
+        blocks = len(self.joint_blocks)
+        for i in range(blocks):
+            context, x = self.joint_blocks[i](
                 context,
                 x,
                 c=c_mod,
                 use_checkpoint=self.use_checkpoint,
             )
+            if control is not None:
+                control_o = control.get("output")
+                if i < len(control_o):
+                    add = control_o[i]
+                    if add is not None:
+                        x += add

         x = self.final_layer(x, c_mod)  # (N, T, patch_size ** 2 * out_channels)
         return x
@@ -922,6 +936,7 @@ class MMDiT(nn.Module):
         t: torch.Tensor,
         y: Optional[torch.Tensor] = None,
         context: Optional[torch.Tensor] = None,
+        control = None,
     ) -> torch.Tensor:
         """
         Forward pass of DiT.
@@ -943,7 +958,7 @@ class MMDiT(nn.Module):
         if context is not None:
             context = self.context_embedder(context)

-        x = self.forward_core_with_concat(x, c, context)
+        x = self.forward_core_with_concat(x, c, context, control)

         x = self.unpatchify(x, hw=hw)  # (N, out_channels, H, W)
         return x[:,:,:hw[-2],:hw[-1]]
@@ -956,7 +971,8 @@ class OpenAISignatureMMDITWrapper(MMDiT):
         timesteps: torch.Tensor,
         context: Optional[torch.Tensor] = None,
         y: Optional[torch.Tensor] = None,
+        control = None,
         **kwargs,
     ) -> torch.Tensor:
-        return super().forward(x, timesteps, context=context, y=y)
+        return super().forward(x, timesteps, context=context, y=y, control=control)
```
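A small sketch of the per-block injection this diff adds, with dummy tensors and a hypothetical block count standing in for the real joint blocks:

```python
import torch

# Stand-ins: 24 joint blocks, hidden size 1536, sequence length 16 (all hypothetical).
x = torch.zeros(1, 16, 1536)
control = {"output": [torch.ones_like(x)] * 23}  # one residual per block except the last

for i in range(24):
    # ... context, x = self.joint_blocks[i](...) would run here ...
    if control is not None:
        control_o = control.get("output")
        if i < len(control_o):  # the final block gets no residual
            add = control_o[i]
            if add is not None:
                x += add
```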
comfy/model_detection.py
```diff
@@ -41,7 +41,9 @@ def detect_unet_config(state_dict, key_prefix):
         unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
         patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2]
         unet_config["patch_size"] = patch_size
-        unet_config["out_channels"] = state_dict['{}final_layer.linear.weight'.format(key_prefix)].shape[0] // (patch_size * patch_size)
+        final_layer = '{}final_layer.linear.weight'.format(key_prefix)
+        if final_layer in state_dict:
+            unet_config["out_channels"] = state_dict[final_layer].shape[0] // (patch_size * patch_size)

         unet_config["depth"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0] // 64
         unet_config["input_size"] = None
@@ -435,10 +437,11 @@ def model_config_from_diffusers_unet(state_dict):
     return None

 def convert_diffusers_mmdit(state_dict, output_prefix=""):
-    depth = count_blocks(state_dict, 'transformer_blocks.{}.')
-    if depth > 0:
+    num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
+    if num_blocks > 0:
+        depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
         out_sd = {}
-        sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth}, output_prefix=output_prefix)
+        sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
         for k in sd_map:
             weight = state_dict.get(k, None)
             if weight is not None:
```
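Since hidden_size is always 64 * depth in this MMDiT, the depth can be read straight off the patch-embedding weight; a worked example with SD3-medium-like shapes (assumed for illustration, not taken from the commit):

```python
import torch

# x_embedder.proj.weight has shape (hidden_size, in_channels, patch, patch);
# with hidden_size = 64 * depth, the depth falls out of the first dimension.
proj_weight = torch.zeros(1536, 16, 2, 2)  # hypothetical SD3-medium-like shape
depth = proj_weight.shape[0] // 64   # 1536 // 64 == 24
in_channels = proj_weight.shape[1]   # 16 latent channels
patch_size = proj_weight.shape[2]    # 2
assert (depth, in_channels, patch_size) == (24, 16, 2)
```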
comfy/utils.py
```diff
@@ -298,7 +298,8 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
     key_map = {}

     depth = mmdit_config.get("depth", 0)
-    for i in range(depth):
+    num_blocks = mmdit_config.get("num_blocks", depth)
+    for i in range(num_blocks):
         block_from = "transformer_blocks.{}".format(i)
         block_to = "{}joint_blocks.{}".format(output_prefix, i)
```
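The effect of the new "num_blocks" key is to decouple how many joint blocks get mapped from the depth used for sizing; a short illustration (the config values are hypothetical):

```python
import comfy.utils

# A controlnet with 12 joint blocks sized like a depth-24 (hidden 1536) model:
# key_map entries are generated for transformer_blocks.0 .. transformer_blocks.11
# instead of 0 .. 23, while "depth" still drives the derived tensor shapes.
sd_map = comfy.utils.mmdit_to_diffusers({"depth": 24, "num_blocks": 12})
```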