Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
32b7e7e7
Commit
32b7e7e7
authored
Dec 12, 2023
by
comfyanonymous
Browse files
Add manual cast to controlnet.
parent
3152023f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
46 additions
and
36 deletions
+46
-36
comfy/cldm/cldm.py
comfy/cldm/cldm.py
+14
-14
comfy/controlnet.py
comfy/controlnet.py
+32
-22
No files found.
comfy/cldm/cldm.py
View file @
32b7e7e7
...
...
@@ -141,24 +141,24 @@ class ControlNet(nn.Module):
)
]
)
self
.
zero_convs
=
nn
.
ModuleList
([
self
.
make_zero_conv
(
model_channels
,
operations
=
operations
)])
self
.
zero_convs
=
nn
.
ModuleList
([
self
.
make_zero_conv
(
model_channels
,
operations
=
operations
,
dtype
=
self
.
dtype
,
device
=
device
)])
self
.
input_hint_block
=
TimestepEmbedSequential
(
operations
.
conv_nd
(
dims
,
hint_channels
,
16
,
3
,
padding
=
1
),
operations
.
conv_nd
(
dims
,
hint_channels
,
16
,
3
,
padding
=
1
,
dtype
=
self
.
dtype
,
device
=
device
),
nn
.
SiLU
(),
operations
.
conv_nd
(
dims
,
16
,
16
,
3
,
padding
=
1
),
operations
.
conv_nd
(
dims
,
16
,
16
,
3
,
padding
=
1
,
dtype
=
self
.
dtype
,
device
=
device
),
nn
.
SiLU
(),
operations
.
conv_nd
(
dims
,
16
,
32
,
3
,
padding
=
1
,
stride
=
2
),
operations
.
conv_nd
(
dims
,
16
,
32
,
3
,
padding
=
1
,
stride
=
2
,
dtype
=
self
.
dtype
,
device
=
device
),
nn
.
SiLU
(),
operations
.
conv_nd
(
dims
,
32
,
32
,
3
,
padding
=
1
),
operations
.
conv_nd
(
dims
,
32
,
32
,
3
,
padding
=
1
,
dtype
=
self
.
dtype
,
device
=
device
),
nn
.
SiLU
(),
operations
.
conv_nd
(
dims
,
32
,
96
,
3
,
padding
=
1
,
stride
=
2
),
operations
.
conv_nd
(
dims
,
32
,
96
,
3
,
padding
=
1
,
stride
=
2
,
dtype
=
self
.
dtype
,
device
=
device
),
nn
.
SiLU
(),
operations
.
conv_nd
(
dims
,
96
,
96
,
3
,
padding
=
1
),
operations
.
conv_nd
(
dims
,
96
,
96
,
3
,
padding
=
1
,
dtype
=
self
.
dtype
,
device
=
device
),
nn
.
SiLU
(),
operations
.
conv_nd
(
dims
,
96
,
256
,
3
,
padding
=
1
,
stride
=
2
),
operations
.
conv_nd
(
dims
,
96
,
256
,
3
,
padding
=
1
,
stride
=
2
,
dtype
=
self
.
dtype
,
device
=
device
),
nn
.
SiLU
(),
zero_module
(
operations
.
conv_nd
(
dims
,
256
,
model_channels
,
3
,
padding
=
1
)
)
operations
.
conv_nd
(
dims
,
256
,
model_channels
,
3
,
padding
=
1
,
dtype
=
self
.
dtype
,
device
=
device
)
)
self
.
_feature_size
=
model_channels
...
...
@@ -206,7 +206,7 @@ class ControlNet(nn.Module):
)
)
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
*
layers
))
self
.
zero_convs
.
append
(
self
.
make_zero_conv
(
ch
,
operations
=
operations
))
self
.
zero_convs
.
append
(
self
.
make_zero_conv
(
ch
,
operations
=
operations
,
dtype
=
self
.
dtype
,
device
=
device
))
self
.
_feature_size
+=
ch
input_block_chans
.
append
(
ch
)
if
level
!=
len
(
channel_mult
)
-
1
:
...
...
@@ -234,7 +234,7 @@ class ControlNet(nn.Module):
)
ch
=
out_ch
input_block_chans
.
append
(
ch
)
self
.
zero_convs
.
append
(
self
.
make_zero_conv
(
ch
,
operations
=
operations
))
self
.
zero_convs
.
append
(
self
.
make_zero_conv
(
ch
,
operations
=
operations
,
dtype
=
self
.
dtype
,
device
=
device
))
ds
*=
2
self
.
_feature_size
+=
ch
...
...
@@ -276,11 +276,11 @@ class ControlNet(nn.Module):
operations
=
operations
)]
self
.
middle_block
=
TimestepEmbedSequential
(
*
mid_block
)
self
.
middle_block_out
=
self
.
make_zero_conv
(
ch
,
operations
=
operations
)
self
.
middle_block_out
=
self
.
make_zero_conv
(
ch
,
operations
=
operations
,
dtype
=
self
.
dtype
,
device
=
device
)
self
.
_feature_size
+=
ch
def
make_zero_conv
(
self
,
channels
,
operations
=
None
):
return
TimestepEmbedSequential
(
zero_module
(
operations
.
conv_nd
(
self
.
dims
,
channels
,
channels
,
1
,
padding
=
0
)
))
def
make_zero_conv
(
self
,
channels
,
operations
=
None
,
dtype
=
None
,
device
=
None
):
return
TimestepEmbedSequential
(
operations
.
conv_nd
(
self
.
dims
,
channels
,
channels
,
1
,
padding
=
0
,
dtype
=
dtype
,
device
=
device
))
def
forward
(
self
,
x
,
hint
,
timesteps
,
context
,
y
=
None
,
**
kwargs
):
t_emb
=
timestep_embedding
(
timesteps
,
self
.
model_channels
,
repeat_only
=
False
).
to
(
x
.
dtype
)
...
...
comfy/controlnet.py
View file @
32b7e7e7
...
...
@@ -36,13 +36,13 @@ class ControlBase:
self
.
cond_hint
=
None
self
.
strength
=
1.0
self
.
timestep_percent_range
=
(
0.0
,
1.0
)
self
.
global_average_pooling
=
False
self
.
timestep_range
=
None
if
device
is
None
:
device
=
comfy
.
model_management
.
get_torch_device
()
self
.
device
=
device
self
.
previous_controlnet
=
None
self
.
global_average_pooling
=
False
def
set_cond_hint
(
self
,
cond_hint
,
strength
=
1.0
,
timestep_percent_range
=
(
0.0
,
1.0
)):
self
.
cond_hint_original
=
cond_hint
...
...
@@ -77,6 +77,7 @@ class ControlBase:
c
.
cond_hint_original
=
self
.
cond_hint_original
c
.
strength
=
self
.
strength
c
.
timestep_percent_range
=
self
.
timestep_percent_range
c
.
global_average_pooling
=
self
.
global_average_pooling
def
inference_memory_requirements
(
self
,
dtype
):
if
self
.
previous_controlnet
is
not
None
:
...
...
@@ -129,12 +130,14 @@ class ControlBase:
return
out
class
ControlNet
(
ControlBase
):
def
__init__
(
self
,
control_model
,
global_average_pooling
=
False
,
device
=
None
):
def
__init__
(
self
,
control_model
,
global_average_pooling
=
False
,
device
=
None
,
load_device
=
None
,
manual_cast_dtype
=
None
):
super
().
__init__
(
device
)
self
.
control_model
=
control_model
self
.
control_model_wrapped
=
comfy
.
model_patcher
.
ModelPatcher
(
self
.
control_model
,
load_device
=
comfy
.
model_management
.
get_torch_device
(),
offload_device
=
comfy
.
model_management
.
unet_offload_device
())
self
.
load_device
=
load_device
self
.
control_model_wrapped
=
comfy
.
model_patcher
.
ModelPatcher
(
self
.
control_model
,
load_device
=
load_device
,
offload_device
=
comfy
.
model_management
.
unet_offload_device
())
self
.
global_average_pooling
=
global_average_pooling
self
.
model_sampling_current
=
None
self
.
manual_cast_dtype
=
manual_cast_dtype
def
get_control
(
self
,
x_noisy
,
t
,
cond
,
batched_number
):
control_prev
=
None
...
...
@@ -149,11 +152,8 @@ class ControlNet(ControlBase):
return
None
dtype
=
self
.
control_model
.
dtype
if
comfy
.
model_management
.
supports_dtype
(
self
.
device
,
dtype
):
precision_scope
=
lambda
a
:
contextlib
.
nullcontext
(
a
)
else
:
precision_scope
=
torch
.
autocast
dtype
=
torch
.
float32
if
self
.
manual_cast_dtype
is
not
None
:
dtype
=
self
.
manual_cast_dtype
output_dtype
=
x_noisy
.
dtype
if
self
.
cond_hint
is
None
or
x_noisy
.
shape
[
2
]
*
8
!=
self
.
cond_hint
.
shape
[
2
]
or
x_noisy
.
shape
[
3
]
*
8
!=
self
.
cond_hint
.
shape
[
3
]:
...
...
@@ -171,12 +171,11 @@ class ControlNet(ControlBase):
timestep
=
self
.
model_sampling_current
.
timestep
(
t
)
x_noisy
=
self
.
model_sampling_current
.
calculate_input
(
t
,
x_noisy
)
with
precision_scope
(
comfy
.
model_management
.
get_autocast_device
(
self
.
device
)):
control
=
self
.
control_model
(
x
=
x_noisy
.
to
(
dtype
),
hint
=
self
.
cond_hint
,
timesteps
=
timestep
.
float
(),
context
=
context
.
to
(
dtype
),
y
=
y
)
return
self
.
control_merge
(
None
,
control
,
control_prev
,
output_dtype
)
def
copy
(
self
):
c
=
ControlNet
(
self
.
control_model
,
global_average_pooling
=
self
.
global_average_pooling
)
c
=
ControlNet
(
self
.
control_model
,
global_average_pooling
=
self
.
global_average_pooling
,
load_device
=
self
.
load_device
,
manual_cast_dtype
=
self
.
manual_cast_dtype
)
self
.
copy_to
(
c
)
return
c
...
...
@@ -207,10 +206,11 @@ class ControlLoraOps:
self
.
bias
=
None
def
forward
(
self
,
input
):
weight
,
bias
=
comfy
.
ops
.
cast_bias_weight
(
self
,
input
)
if
self
.
up
is
not
None
:
return
torch
.
nn
.
functional
.
linear
(
input
,
self
.
weight
.
to
(
dtype
=
input
.
dtype
,
device
=
input
.
device
)
+
(
torch
.
mm
(
self
.
up
.
flatten
(
start_dim
=
1
),
self
.
down
.
flatten
(
start_dim
=
1
))).
reshape
(
self
.
weight
.
shape
).
type
(
input
.
dtype
),
self
.
bias
)
return
torch
.
nn
.
functional
.
linear
(
input
,
weight
+
(
torch
.
mm
(
self
.
up
.
flatten
(
start_dim
=
1
),
self
.
down
.
flatten
(
start_dim
=
1
))).
reshape
(
self
.
weight
.
shape
).
type
(
input
.
dtype
),
bias
)
else
:
return
torch
.
nn
.
functional
.
linear
(
input
,
self
.
weight
.
to
(
dtype
=
input
.
dtype
,
device
=
input
.
device
),
self
.
bias
)
return
torch
.
nn
.
functional
.
linear
(
input
,
weight
,
bias
)
class
Conv2d
(
torch
.
nn
.
Module
):
def
__init__
(
...
...
@@ -246,10 +246,11 @@ class ControlLoraOps:
def
forward
(
self
,
input
):
weight
,
bias
=
comfy
.
ops
.
cast_bias_weight
(
self
,
input
)
if
self
.
up
is
not
None
:
return
torch
.
nn
.
functional
.
conv2d
(
input
,
self
.
weight
.
to
(
dtype
=
input
.
dtype
,
device
=
input
.
device
)
+
(
torch
.
mm
(
self
.
up
.
flatten
(
start_dim
=
1
),
self
.
down
.
flatten
(
start_dim
=
1
))).
reshape
(
self
.
weight
.
shape
).
type
(
input
.
dtype
),
self
.
bias
,
self
.
stride
,
self
.
padding
,
self
.
dilation
,
self
.
groups
)
return
torch
.
nn
.
functional
.
conv2d
(
input
,
weight
+
(
torch
.
mm
(
self
.
up
.
flatten
(
start_dim
=
1
),
self
.
down
.
flatten
(
start_dim
=
1
))).
reshape
(
self
.
weight
.
shape
).
type
(
input
.
dtype
),
bias
,
self
.
stride
,
self
.
padding
,
self
.
dilation
,
self
.
groups
)
else
:
return
torch
.
nn
.
functional
.
conv2d
(
input
,
self
.
weight
.
to
(
dtype
=
input
.
dtype
,
device
=
input
.
device
),
self
.
bias
,
self
.
stride
,
self
.
padding
,
self
.
dilation
,
self
.
groups
)
return
torch
.
nn
.
functional
.
conv2d
(
input
,
weight
,
bias
,
self
.
stride
,
self
.
padding
,
self
.
dilation
,
self
.
groups
)
class
ControlLora
(
ControlNet
):
...
...
@@ -263,12 +264,19 @@ class ControlLora(ControlNet):
controlnet_config
=
model
.
model_config
.
unet_config
.
copy
()
controlnet_config
.
pop
(
"out_channels"
)
controlnet_config
[
"hint_channels"
]
=
self
.
control_weights
[
"input_hint_block.0.weight"
].
shape
[
1
]
self
.
manual_cast_dtype
=
model
.
manual_cast_dtype
dtype
=
model
.
get_dtype
()
if
self
.
manual_cast_dtype
is
None
:
class
control_lora_ops
(
ControlLoraOps
,
comfy
.
ops
.
disable_weight_init
):
pass
else
:
class
control_lora_ops
(
ControlLoraOps
,
comfy
.
ops
.
manual_cast
):
pass
dtype
=
self
.
manual_cast_dtype
controlnet_config
[
"operations"
]
=
control_lora_ops
controlnet_config
[
"dtype"
]
=
dtype
self
.
control_model
=
comfy
.
cldm
.
cldm
.
ControlNet
(
**
controlnet_config
)
dtype
=
model
.
get_dtype
()
self
.
control_model
.
to
(
dtype
)
self
.
control_model
.
to
(
comfy
.
model_management
.
get_torch_device
())
diffusion_model
=
model
.
diffusion_model
sd
=
diffusion_model
.
state_dict
()
...
...
@@ -372,6 +380,10 @@ def load_controlnet(ckpt_path, model=None):
if
controlnet_config
is
None
:
unet_dtype
=
comfy
.
model_management
.
unet_dtype
()
controlnet_config
=
comfy
.
model_detection
.
model_config_from_unet
(
controlnet_data
,
prefix
,
unet_dtype
,
True
).
unet_config
load_device
=
comfy
.
model_management
.
get_torch_device
()
manual_cast_dtype
=
comfy
.
model_management
.
unet_manual_cast
(
unet_dtype
,
load_device
)
if
manual_cast_dtype
is
not
None
:
controlnet_config
[
"operations"
]
=
comfy
.
ops
.
manual_cast
controlnet_config
.
pop
(
"out_channels"
)
controlnet_config
[
"hint_channels"
]
=
controlnet_data
[
"{}input_hint_block.0.weight"
.
format
(
prefix
)].
shape
[
1
]
control_model
=
comfy
.
cldm
.
cldm
.
ControlNet
(
**
controlnet_config
)
...
...
@@ -400,14 +412,12 @@ def load_controlnet(ckpt_path, model=None):
missing
,
unexpected
=
control_model
.
load_state_dict
(
controlnet_data
,
strict
=
False
)
print
(
missing
,
unexpected
)
control_model
=
control_model
.
to
(
unet_dtype
)
global_average_pooling
=
False
filename
=
os
.
path
.
splitext
(
ckpt_path
)[
0
]
if
filename
.
endswith
(
"_shuffle"
)
or
filename
.
endswith
(
"_shuffle_fp16"
):
#TODO: smarter way of enabling global_average_pooling
global_average_pooling
=
True
control
=
ControlNet
(
control_model
,
global_average_pooling
=
global_average_pooling
)
control
=
ControlNet
(
control_model
,
global_average_pooling
=
global_average_pooling
,
load_device
=
load_device
,
manual_cast_dtype
=
manual_cast_dtype
)
return
control
class
T2IAdapter
(
ControlBase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment