Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
cf5ae469
"deploy/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "55d7190255c1fcc8997ac9b2a5feedbaf6754817"
Commit
cf5ae469
authored
Aug 21, 2023
by
comfyanonymous
Browse files
Controlnet/t2iadapter cleanup.
parent
763b0cf0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
58 additions
and
55 deletions
+58
-55
comfy/ldm/modules/diffusionmodules/openaimodel.py
comfy/ldm/modules/diffusionmodules/openaimodel.py
+3
-1
comfy/sd.py
comfy/sd.py
+51
-54
comfy/t2i_adapter/adapter.py
comfy/t2i_adapter/adapter.py
+4
-0
No files found.
comfy/ldm/modules/diffusionmodules/openaimodel.py
View file @
cf5ae469
...
...
@@ -632,7 +632,9 @@ class UNetModel(nn.Module):
transformer_options
[
"block"
]
=
(
"middle"
,
0
)
h
=
forward_timestep_embed
(
self
.
middle_block
,
h
,
emb
,
context
,
transformer_options
)
if
control
is
not
None
and
'middle'
in
control
and
len
(
control
[
'middle'
])
>
0
:
h
+=
control
[
'middle'
].
pop
()
ctrl
=
control
[
'middle'
].
pop
()
if
ctrl
is
not
None
:
h
+=
ctrl
for
id
,
module
in
enumerate
(
self
.
output_blocks
):
transformer_options
[
"block"
]
=
(
"output"
,
id
)
...
...
comfy/sd.py
View file @
cf5ae469
...
...
@@ -742,6 +742,7 @@ class ControlBase:
device
=
model_management
.
get_torch_device
()
self
.
device
=
device
self
.
previous_controlnet
=
None
self
.
global_average_pooling
=
False
def
set_cond_hint
(
self
,
cond_hint
,
strength
=
1.0
,
timestep_percent_range
=
(
1.0
,
0.0
)):
self
.
cond_hint_original
=
cond_hint
...
...
@@ -777,6 +778,51 @@ class ControlBase:
c
.
strength
=
self
.
strength
c
.
timestep_percent_range
=
self
.
timestep_percent_range
def
control_merge
(
self
,
control_input
,
control_output
,
control_prev
,
output_dtype
):
out
=
{
'input'
:[],
'middle'
:[],
'output'
:
[]}
if
control_input
is
not
None
:
for
i
in
range
(
len
(
control_input
)):
key
=
'input'
x
=
control_input
[
i
]
if
x
is
not
None
:
x
*=
self
.
strength
if
x
.
dtype
!=
output_dtype
:
x
=
x
.
to
(
output_dtype
)
out
[
key
].
insert
(
0
,
x
)
if
control_output
is
not
None
:
for
i
in
range
(
len
(
control_output
)):
if
i
==
(
len
(
control_output
)
-
1
):
key
=
'middle'
index
=
0
else
:
key
=
'output'
index
=
i
x
=
control_output
[
i
]
if
x
is
not
None
:
if
self
.
global_average_pooling
:
x
=
torch
.
mean
(
x
,
dim
=
(
2
,
3
),
keepdim
=
True
).
repeat
(
1
,
1
,
x
.
shape
[
2
],
x
.
shape
[
3
])
x
*=
self
.
strength
if
x
.
dtype
!=
output_dtype
:
x
=
x
.
to
(
output_dtype
)
out
[
key
].
append
(
x
)
if
control_prev
is
not
None
:
for
x
in
[
'input'
,
'middle'
,
'output'
]:
o
=
out
[
x
]
for
i
in
range
(
len
(
control_prev
[
x
])):
prev_val
=
control_prev
[
x
][
i
]
if
i
>=
len
(
o
):
o
.
append
(
prev_val
)
elif
prev_val
is
not
None
:
if
o
[
i
]
is
None
:
o
[
i
]
=
prev_val
else
:
o
[
i
]
+=
prev_val
return
out
class
ControlNet
(
ControlBase
):
def
__init__
(
self
,
control_model
,
global_average_pooling
=
False
,
device
=
None
):
super
().
__init__
(
device
)
...
...
@@ -811,32 +857,7 @@ class ControlNet(ControlBase):
if
y
is
not
None
:
y
=
y
.
to
(
self
.
control_model
.
dtype
)
control
=
self
.
control_model
(
x
=
x_noisy
.
to
(
self
.
control_model
.
dtype
),
hint
=
self
.
cond_hint
,
timesteps
=
t
,
context
=
context
.
to
(
self
.
control_model
.
dtype
),
y
=
y
)
out
=
{
'middle'
:[],
'output'
:
[]}
for
i
in
range
(
len
(
control
)):
if
i
==
(
len
(
control
)
-
1
):
key
=
'middle'
index
=
0
else
:
key
=
'output'
index
=
i
x
=
control
[
i
]
if
self
.
global_average_pooling
:
x
=
torch
.
mean
(
x
,
dim
=
(
2
,
3
),
keepdim
=
True
).
repeat
(
1
,
1
,
x
.
shape
[
2
],
x
.
shape
[
3
])
x
*=
self
.
strength
if
x
.
dtype
!=
output_dtype
:
x
=
x
.
to
(
output_dtype
)
if
control_prev
is
not
None
and
key
in
control_prev
:
prev
=
control_prev
[
key
][
index
]
if
prev
is
not
None
:
x
+=
prev
out
[
key
].
append
(
x
)
if
control_prev
is
not
None
and
'input'
in
control_prev
:
out
[
'input'
]
=
control_prev
[
'input'
]
return
out
return
self
.
control_merge
(
None
,
control
,
control_prev
,
output_dtype
)
def
copy
(
self
):
c
=
ControlNet
(
self
.
control_model
,
global_average_pooling
=
self
.
global_average_pooling
)
...
...
@@ -1101,37 +1122,13 @@ class T2IAdapter(ControlBase):
if
x_noisy
.
shape
[
0
]
!=
self
.
cond_hint
.
shape
[
0
]:
self
.
cond_hint
=
broadcast_image_to
(
self
.
cond_hint
,
x_noisy
.
shape
[
0
],
batched_number
)
if
self
.
control_input
is
None
:
self
.
t2i_model
.
to
(
x_noisy
.
dtype
)
self
.
t2i_model
.
to
(
self
.
device
)
self
.
control_input
=
self
.
t2i_model
(
self
.
cond_hint
)
self
.
control_input
=
self
.
t2i_model
(
self
.
cond_hint
.
to
(
x_noisy
.
dtype
)
)
self
.
t2i_model
.
cpu
()
output_dtype
=
x_noisy
.
dtype
out
=
{
'input'
:[]}
for
i
in
range
(
len
(
self
.
control_input
)):
key
=
'input'
x
=
self
.
control_input
[
i
]
*
self
.
strength
if
x
.
dtype
!=
output_dtype
:
x
=
x
.
to
(
output_dtype
)
if
control_prev
is
not
None
and
key
in
control_prev
:
index
=
len
(
control_prev
[
key
])
-
i
*
3
-
3
prev
=
control_prev
[
key
][
index
]
if
prev
is
not
None
:
x
+=
prev
out
[
key
].
insert
(
0
,
None
)
out
[
key
].
insert
(
0
,
None
)
out
[
key
].
insert
(
0
,
x
)
if
control_prev
is
not
None
and
'input'
in
control_prev
:
for
i
in
range
(
len
(
out
[
'input'
])):
if
out
[
'input'
][
i
]
is
None
:
out
[
'input'
][
i
]
=
control_prev
[
'input'
][
i
]
if
control_prev
is
not
None
and
'middle'
in
control_prev
:
out
[
'middle'
]
=
control_prev
[
'middle'
]
if
control_prev
is
not
None
and
'output'
in
control_prev
:
out
[
'output'
]
=
control_prev
[
'output'
]
return
out
control_input
=
list
(
map
(
lambda
a
:
None
if
a
is
None
else
a
.
clone
(),
self
.
control_input
))
return
self
.
control_merge
(
control_input
,
None
,
control_prev
,
x_noisy
.
dtype
)
def
copy
(
self
):
c
=
T2IAdapter
(
self
.
t2i_model
,
self
.
channels_in
)
...
...
comfy/t2i_adapter/adapter.py
View file @
cf5ae469
...
...
@@ -128,6 +128,8 @@ class Adapter(nn.Module):
for
j
in
range
(
self
.
nums_rb
):
idx
=
i
*
self
.
nums_rb
+
j
x
=
self
.
body
[
idx
](
x
)
features
.
append
(
None
)
features
.
append
(
None
)
features
.
append
(
x
)
return
features
...
...
@@ -259,6 +261,8 @@ class Adapter_light(nn.Module):
features
=
[]
for
i
in
range
(
len
(
self
.
channels
)):
x
=
self
.
body
[
i
](
x
)
features
.
append
(
None
)
features
.
append
(
None
)
features
.
append
(
x
)
return
features
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment