Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
03e83bb5
Commit
03e83bb5
authored
Mar 06, 2024
by
comfyanonymous
Browse files
Support stable cascade canny controlnet.
parent
10860bcd
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
116 additions
and
10 deletions
+116
-10
comfy/controlnet.py
comfy/controlnet.py
+12
-3
comfy/ldm/cascade/controlnet.py
comfy/ldm/cascade/controlnet.py
+94
-0
comfy/ldm/cascade/stage_c.py
comfy/ldm/cascade/stage_c.py
+10
-7
No files found.
comfy/controlnet.py
View file @
03e83bb5
...
...
@@ -9,6 +9,7 @@ import comfy.ops
import
comfy.cldm.cldm
import
comfy.t2i_adapter.adapter
import
comfy.ldm.cascade.controlnet
def
broadcast_image_to
(
tensor
,
target_batch_size
,
batched_number
):
...
...
@@ -78,6 +79,7 @@ class ControlBase:
c
.
strength
=
self
.
strength
c
.
timestep_percent_range
=
self
.
timestep_percent_range
c
.
global_average_pooling
=
self
.
global_average_pooling
c
.
compression_ratio
=
self
.
compression_ratio
def
inference_memory_requirements
(
self
,
dtype
):
if
self
.
previous_controlnet
is
not
None
:
...
...
@@ -433,11 +435,12 @@ def load_controlnet(ckpt_path, model=None):
return
control
class
T2IAdapter
(
ControlBase
):
def
__init__
(
self
,
t2i_model
,
channels_in
,
device
=
None
):
def
__init__
(
self
,
t2i_model
,
channels_in
,
compression_ratio
,
device
=
None
):
super
().
__init__
(
device
)
self
.
t2i_model
=
t2i_model
self
.
channels_in
=
channels_in
self
.
control_input
=
None
self
.
compression_ratio
=
compression_ratio
def
scale_image_to
(
self
,
width
,
height
):
unshuffle_amount
=
self
.
t2i_model
.
unshuffle_amount
...
...
@@ -482,11 +485,13 @@ class T2IAdapter(ControlBase):
return
self
.
control_merge
(
control_input
,
mid
,
control_prev
,
x_noisy
.
dtype
)
def
copy
(
self
):
c
=
T2IAdapter
(
self
.
t2i_model
,
self
.
channels_in
)
c
=
T2IAdapter
(
self
.
t2i_model
,
self
.
channels_in
,
self
.
compression_ratio
)
self
.
copy_to
(
c
)
return
c
def
load_t2i_adapter
(
t2i_data
):
compression_ratio
=
8
if
'adapter'
in
t2i_data
:
t2i_data
=
t2i_data
[
'adapter'
]
if
'adapter.body.0.resnets.0.block1.weight'
in
t2i_data
:
#diffusers format
...
...
@@ -514,8 +519,12 @@ def load_t2i_adapter(t2i_data):
if
cin
==
256
or
cin
==
768
:
xl
=
True
model_ad
=
comfy
.
t2i_adapter
.
adapter
.
Adapter
(
cin
=
cin
,
channels
=
[
channel
,
channel
*
2
,
channel
*
4
,
channel
*
4
][:
4
],
nums_rb
=
2
,
ksize
=
ksize
,
sk
=
True
,
use_conv
=
use_conv
,
xl
=
xl
)
elif
"backbone.0.0.weight"
in
keys
:
model_ad
=
comfy
.
ldm
.
cascade
.
controlnet
.
ControlNet
(
c_in
=
t2i_data
[
'backbone.0.0.weight'
].
shape
[
1
],
proj_blocks
=
[
0
,
4
,
8
,
12
,
51
,
55
,
59
,
63
])
compression_ratio
=
32
else
:
return
None
missing
,
unexpected
=
model_ad
.
load_state_dict
(
t2i_data
)
if
len
(
missing
)
>
0
:
print
(
"t2i missing"
,
missing
)
...
...
@@ -523,4 +532,4 @@ def load_t2i_adapter(t2i_data):
if
len
(
unexpected
)
>
0
:
print
(
"t2i unexpected"
,
unexpected
)
return
T2IAdapter
(
model_ad
,
model_ad
.
input_channels
)
return
T2IAdapter
(
model_ad
,
model_ad
.
input_channels
,
compression_ratio
)
comfy/ldm/cascade/controlnet.py
0 → 100644
View file @
03e83bb5
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import
torch
import
torchvision
from
torch
import
nn
from
.common
import
LayerNorm2d_op
class CNetResBlock(nn.Module):
    """Channel-preserving residual block for the 'large' ControlNet bottleneck.

    Applies two (LayerNorm2d -> GELU -> 3x3 same-padded Conv2d) stages and
    adds the result back onto the input, so spatial size and channel count
    `c` are unchanged.

    Args:
        c: number of input/output channels.
        dtype: parameter dtype forwarded to the layer factories.
        device: parameter device forwarded to the layer factories.
        operations: namespace providing `Conv2d` (e.g. `torch.nn` or a
            comfy ops module); also consumed by `LayerNorm2d_op`.
    """
    def __init__(self, c, dtype=None, device=None, operations=None):
        super().__init__()
        self.blocks = nn.Sequential(
            LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
            nn.GELU(),
            # Fix: forward dtype/device to the convs as well — previously only
            # the LayerNorms received them, so conv weights were created on the
            # default device/dtype regardless of the requested ones.
            operations.Conv2d(c, c, kernel_size=3, padding=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
            nn.GELU(),
            operations.Conv2d(c, c, kernel_size=3, padding=1, dtype=dtype, device=device),
        )

    def forward(self, x):
        """Return `x + blocks(x)` (residual connection)."""
        return x + self.blocks(x)
class
ControlNet
(
nn
.
Module
):
def
__init__
(
self
,
c_in
=
3
,
c_proj
=
2048
,
proj_blocks
=
None
,
bottleneck_mode
=
None
,
dtype
=
None
,
device
=
None
,
operations
=
nn
):
super
().
__init__
()
if
bottleneck_mode
is
None
:
bottleneck_mode
=
'effnet'
self
.
proj_blocks
=
proj_blocks
if
bottleneck_mode
==
'effnet'
:
embd_channels
=
1280
self
.
backbone
=
torchvision
.
models
.
efficientnet_v2_s
().
features
.
eval
()
if
c_in
!=
3
:
in_weights
=
self
.
backbone
[
0
][
0
].
weight
.
data
self
.
backbone
[
0
][
0
]
=
operations
.
Conv2d
(
c_in
,
24
,
kernel_size
=
3
,
stride
=
2
,
bias
=
False
,
dtype
=
dtype
,
device
=
device
)
if
c_in
>
3
:
# nn.init.constant_(self.backbone[0][0].weight, 0)
self
.
backbone
[
0
][
0
].
weight
.
data
[:,
:
3
]
=
in_weights
[:,
:
3
].
clone
()
else
:
self
.
backbone
[
0
][
0
].
weight
.
data
=
in_weights
[:,
:
c_in
].
clone
()
elif
bottleneck_mode
==
'simple'
:
embd_channels
=
c_in
self
.
backbone
=
nn
.
Sequential
(
operations
.
Conv2d
(
embd_channels
,
embd_channels
*
4
,
kernel_size
=
3
,
padding
=
1
,
dtype
=
dtype
,
device
=
device
),
nn
.
LeakyReLU
(
0.2
,
inplace
=
True
),
operations
.
Conv2d
(
embd_channels
*
4
,
embd_channels
,
kernel_size
=
3
,
padding
=
1
,
dtype
=
dtype
,
device
=
device
),
)
elif
bottleneck_mode
==
'large'
:
self
.
backbone
=
nn
.
Sequential
(
operations
.
Conv2d
(
c_in
,
4096
*
4
,
kernel_size
=
1
,
dtype
=
dtype
,
device
=
device
),
nn
.
LeakyReLU
(
0.2
,
inplace
=
True
),
operations
.
Conv2d
(
4096
*
4
,
1024
,
kernel_size
=
1
,
dtype
=
dtype
,
device
=
device
),
*
[
CNetResBlock
(
1024
)
for
_
in
range
(
8
)],
operations
.
Conv2d
(
1024
,
1280
,
kernel_size
=
1
,
dtype
=
dtype
,
device
=
device
),
)
embd_channels
=
1280
else
:
raise
ValueError
(
f
'Unknown bottleneck mode:
{
bottleneck_mode
}
'
)
self
.
projections
=
nn
.
ModuleList
()
for
_
in
range
(
len
(
proj_blocks
)):
self
.
projections
.
append
(
nn
.
Sequential
(
operations
.
Conv2d
(
embd_channels
,
embd_channels
,
kernel_size
=
1
,
bias
=
False
,
dtype
=
dtype
,
device
=
device
),
nn
.
LeakyReLU
(
0.2
,
inplace
=
True
),
operations
.
Conv2d
(
embd_channels
,
c_proj
,
kernel_size
=
1
,
bias
=
False
,
dtype
=
dtype
,
device
=
device
),
))
# nn.init.constant_(self.projections[-1][-1].weight, 0) # zero output projection
self
.
xl
=
False
self
.
input_channels
=
c_in
self
.
unshuffle_amount
=
8
def
forward
(
self
,
x
):
print
(
x
)
x
=
self
.
backbone
(
x
)
proj_outputs
=
[
None
for
_
in
range
(
max
(
self
.
proj_blocks
)
+
1
)]
for
i
,
idx
in
enumerate
(
self
.
proj_blocks
):
proj_outputs
[
idx
]
=
self
.
projections
[
i
](
x
)
return
proj_outputs
comfy/ldm/cascade/stage_c.py
View file @
03e83bb5
...
...
@@ -194,10 +194,10 @@ class StageC(nn.Module):
hasattr
(
block
,
'_fsdp_wrapped_module'
)
and
isinstance
(
block
.
_fsdp_wrapped_module
,
ResBlock
)):
if
cnet
is
not
None
:
next_cnet
=
cnet
()
next_cnet
=
cnet
.
pop
()
if
next_cnet
is
not
None
:
x
=
x
+
nn
.
functional
.
interpolate
(
next_cnet
,
size
=
x
.
shape
[
-
2
:],
mode
=
'bilinear'
,
align_corners
=
True
)
align_corners
=
True
)
.
to
(
x
.
dtype
)
x
=
block
(
x
)
elif
isinstance
(
block
,
AttnBlock
)
or
(
hasattr
(
block
,
'_fsdp_wrapped_module'
)
and
isinstance
(
block
.
_fsdp_wrapped_module
,
...
...
@@ -228,10 +228,10 @@ class StageC(nn.Module):
x
=
torch
.
nn
.
functional
.
interpolate
(
x
,
skip
.
shape
[
-
2
:],
mode
=
'bilinear'
,
align_corners
=
True
)
if
cnet
is
not
None
:
next_cnet
=
cnet
()
next_cnet
=
cnet
.
pop
()
if
next_cnet
is
not
None
:
x
=
x
+
nn
.
functional
.
interpolate
(
next_cnet
,
size
=
x
.
shape
[
-
2
:],
mode
=
'bilinear'
,
align_corners
=
True
)
align_corners
=
True
)
.
to
(
x
.
dtype
)
x
=
block
(
x
,
skip
)
elif
isinstance
(
block
,
AttnBlock
)
or
(
hasattr
(
block
,
'_fsdp_wrapped_module'
)
and
isinstance
(
block
.
_fsdp_wrapped_module
,
...
...
@@ -248,7 +248,7 @@ class StageC(nn.Module):
x
=
upscaler
(
x
)
return
x
def
forward
(
self
,
x
,
r
,
clip_text
,
clip_text_pooled
,
clip_img
,
c
net
=
None
,
**
kwargs
):
def
forward
(
self
,
x
,
r
,
clip_text
,
clip_text_pooled
,
clip_img
,
c
ontrol
=
None
,
**
kwargs
):
# Process the conditioning embeddings
r_embed
=
self
.
gen_r_embedding
(
r
).
to
(
dtype
=
x
.
dtype
)
for
c
in
self
.
t_conds
:
...
...
@@ -256,10 +256,13 @@ class StageC(nn.Module):
r_embed
=
torch
.
cat
([
r_embed
,
self
.
gen_r_embedding
(
t_cond
).
to
(
dtype
=
x
.
dtype
)],
dim
=
1
)
clip
=
self
.
gen_c_embeddings
(
clip_text
,
clip_text_pooled
,
clip_img
)
if
control
is
not
None
:
cnet
=
control
.
get
(
"input"
)
else
:
cnet
=
None
# Model Blocks
x
=
self
.
embedding
(
x
)
if
cnet
is
not
None
:
cnet
=
ControlNetDeliverer
(
cnet
)
level_outputs
=
self
.
_down_encode
(
x
,
r_embed
,
clip
,
cnet
)
x
=
self
.
_up_decode
(
level_outputs
,
r_embed
,
clip
,
cnet
)
return
self
.
clf
(
x
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment