Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
700491d8
Commit
700491d8
authored
Jun 03, 2023
by
comfyanonymous
Browse files
Implement global average pooling for controlnet.
parent
66e588d8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
3 deletions
+11
-3
comfy/sd.py
comfy/sd.py
+11
-3
No files found.
comfy/sd.py
View file @
700491d8
...
@@ -621,7 +621,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
...
@@ -621,7 +621,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
return
torch
.
cat
([
tensor
]
*
batched_number
,
dim
=
0
)
return
torch
.
cat
([
tensor
]
*
batched_number
,
dim
=
0
)
class
ControlNet
:
class
ControlNet
:
def
__init__
(
self
,
control_model
,
device
=
None
):
def
__init__
(
self
,
control_model
,
global_average_pooling
=
False
,
device
=
None
):
self
.
control_model
=
control_model
self
.
control_model
=
control_model
self
.
cond_hint_original
=
None
self
.
cond_hint_original
=
None
self
.
cond_hint
=
None
self
.
cond_hint
=
None
...
@@ -630,6 +630,7 @@ class ControlNet:
...
@@ -630,6 +630,7 @@ class ControlNet:
device
=
model_management
.
get_torch_device
()
device
=
model_management
.
get_torch_device
()
self
.
device
=
device
self
.
device
=
device
self
.
previous_controlnet
=
None
self
.
previous_controlnet
=
None
self
.
global_average_pooling
=
global_average_pooling
def
get_control
(
self
,
x_noisy
,
t
,
cond_txt
,
batched_number
):
def
get_control
(
self
,
x_noisy
,
t
,
cond_txt
,
batched_number
):
control_prev
=
None
control_prev
=
None
...
@@ -665,6 +666,9 @@ class ControlNet:
...
@@ -665,6 +666,9 @@ class ControlNet:
key
=
'output'
key
=
'output'
index
=
i
index
=
i
x
=
control
[
i
]
x
=
control
[
i
]
if
self
.
global_average_pooling
:
x
=
torch
.
mean
(
x
,
dim
=
(
2
,
3
),
keepdim
=
True
).
repeat
(
1
,
1
,
x
.
shape
[
2
],
x
.
shape
[
3
])
x
*=
self
.
strength
x
*=
self
.
strength
if
x
.
dtype
!=
output_dtype
and
not
autocast_enabled
:
if
x
.
dtype
!=
output_dtype
and
not
autocast_enabled
:
x
=
x
.
to
(
output_dtype
)
x
=
x
.
to
(
output_dtype
)
...
@@ -695,7 +699,7 @@ class ControlNet:
...
@@ -695,7 +699,7 @@ class ControlNet:
self
.
cond_hint
=
None
self
.
cond_hint
=
None
def
copy
(
self
):
def
copy
(
self
):
c
=
ControlNet
(
self
.
control_model
)
c
=
ControlNet
(
self
.
control_model
,
global_average_pooling
=
self
.
global_average_pooling
)
c
.
cond_hint_original
=
self
.
cond_hint_original
c
.
cond_hint_original
=
self
.
cond_hint_original
c
.
strength
=
self
.
strength
c
.
strength
=
self
.
strength
return
c
return
c
...
@@ -790,7 +794,11 @@ def load_controlnet(ckpt_path, model=None):
...
@@ -790,7 +794,11 @@ def load_controlnet(ckpt_path, model=None):
if
use_fp16
:
if
use_fp16
:
control_model
=
control_model
.
half
()
control_model
=
control_model
.
half
()
control
=
ControlNet
(
control_model
)
global_average_pooling
=
False
if
ckpt_path
.
endswith
(
"_shuffle.pth"
)
or
ckpt_path
.
endswith
(
"_shuffle.safetensors"
)
or
ckpt_path
.
endswith
(
"_shuffle_fp16.safetensors"
):
#TODO: smarter way of enabling global_average_pooling
global_average_pooling
=
True
control
=
ControlNet
(
control_model
,
global_average_pooling
=
global_average_pooling
)
return
control
return
control
class
T2IAdapter
:
class
T2IAdapter
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment