ComfyUI

Commit 4efa67fa, authored Feb 16, 2023 by comfyanonymous

Add ControlNet support.

parent bc69fb52

Showing 9 changed files with 580 additions and 63 deletions (+580 -63)
comfy/cldm/cldm.py                                  +286    -0
comfy/extra_samplers/uni_pc.py                        +2    -2
comfy/ldm/models/diffusion/ddpm.py                    +9    -9
comfy/ldm/modules/diffusionmodules/openaimodel.py     +8    -2
comfy/model_management.py                            +16    -1
comfy/samplers.py                                   +100   -21
comfy/sd.py                                          +76    -0
comfy/utils.py                                       +18    -0
nodes.py                                             +65   -28
comfy/cldm/cldm.py (new file, 0 → 100644)
#taken from: https://github.com/lllyasviel/ControlNet
#and modified

import einops
import torch
import torch as th
import torch.nn as nn

from ldm.modules.diffusionmodules.util import (
    conv_nd,
    linear,
    zero_module,
    timestep_embedding,
)

from einops import rearrange, repeat
from torchvision.utils import make_grid
from ldm.modules.attention import SpatialTransformer
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.util import log_txt_as_img, exists, instantiate_from_config


class ControlledUnetModel(UNetModel):
    #implemented in the ldm unet
    pass


class ControlNet(nn.Module):
    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        hint_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,    # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.dims = dims
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])

        self.input_hint_block = TimestepEmbedSequential(
            conv_nd(dims, hint_channels, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 32, 32, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 96, 96, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
            zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
        )

        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            ) if not use_spatial_transformer else SpatialTransformer(
                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_convs.append(self.make_zero_conv(ch))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                self.zero_convs.append(self.make_zero_conv(ch))
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            #num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self.middle_block_out = self.make_zero_conv(ch)
        self._feature_size += ch

    def make_zero_conv(self, channels):
        return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))

    def forward(self, x, hint, timesteps, context, **kwargs):
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        guided_hint = self.input_hint_block(hint, emb, context)

        outs = []

        h = x.type(self.dtype)
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            if guided_hint is not None:
                h = module(h, emb, context)
                h += guided_hint
                guided_hint = None
            else:
                h = module(h, emb, context)
            outs.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)
        outs.append(self.middle_block_out(h, emb, context))

        return outs
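
Worth noting about this new file: every tap into the UNet goes through a zero-initialized convolution (zero_convs, middle_block_out, and the last layer of input_hint_block), so the whole control branch contributes exactly nothing until training moves those weights. A minimal sketch of that property, assuming only torch (the local zero_module below mirrors the idea of ldm's helper; it is not code from this commit):

import torch
import torch.nn as nn

def zero_module(module):
    # zero every parameter in place, as ldm's zero_module does
    for p in module.parameters():
        nn.init.zeros_(p)
    return module

zero_conv = zero_module(nn.Conv2d(320, 320, 1, padding=0))
features = torch.randn(1, 320, 64, 64)
residual = zero_conv(features)
print(residual.abs().max())  # tensor(0.) -- the control branch starts inert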
comfy/extra_samplers/uni_pc.py

@@ -856,13 +856,13 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None
     device = noise.device
-    if model.inner_model.parameterization == "v":
+    if model.parameterization == "v":
         model_type = "v"
     else:
         model_type = "noise"

     model_fn = model_wrapper(
-        model.inner_model.apply_model,
+        model.inner_model.inner_model.apply_model,
         sampling_function,
         ns,
         model_type=model_type,
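
These two edits track the rewiring in comfy/samplers.py below: the raw model now sits one wrapper deeper (CFGNoisePredictor is inserted between the k-diffusion denoiser and the LatentDiffusion model), while parameterization is mirrored onto the outer wrapper. A toy sketch of the new nesting, with stand-in classes (none of these are the real ComfyUI types):

class Inner:                 # stands in for the LatentDiffusion model
    parameterization = "v"
    def apply_model(self):
        return "noise prediction"

class NoisePredictor:        # stands in for samplers.CFGNoisePredictor
    def __init__(self, m):
        self.inner_model = m

class KWrapper:              # stands in for k-diffusion's denoiser wrapper
    def __init__(self, m):
        self.inner_model = m
        # samplers.py now copies this attribute onto the wrapper itself
        self.parameterization = m.inner_model.parameterization

model = KWrapper(NoisePredictor(Inner()))
print(model.parameterization)                       # "v", read directly now
print(model.inner_model.inner_model.apply_model())  # two hops to the raw model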
comfy/ldm/models/diffusion/ddpm.py

@@ -1320,12 +1320,12 @@ class DiffusionWrapper(torch.nn.Module):
         self.conditioning_key = conditioning_key
         assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']

-    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
+    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None):
         if self.conditioning_key is None:
-            out = self.diffusion_model(x, t)
+            out = self.diffusion_model(x, t, control=control)
         elif self.conditioning_key == 'concat':
             xc = torch.cat([x] + c_concat, dim=1)
-            out = self.diffusion_model(xc, t)
+            out = self.diffusion_model(xc, t, control=control)
         elif self.conditioning_key == 'crossattn':
             if not self.sequential_cross_attn:
                 cc = torch.cat(c_crossattn, 1)

@@ -1335,25 +1335,25 @@ class DiffusionWrapper(torch.nn.Module):
                 # TorchScript changes names of the arguments
                 # with argument cc defined as context=cc scripted model will produce
                 # an error: RuntimeError: forward() is missing value for argument 'argument_3'.
-                out = self.scripted_diffusion_model(x, t, cc)
+                out = self.scripted_diffusion_model(x, t, cc, control=control)
             else:
-                out = self.diffusion_model(x, t, context=cc)
+                out = self.diffusion_model(x, t, context=cc, control=control)
         elif self.conditioning_key == 'hybrid':
             xc = torch.cat([x] + c_concat, dim=1)
             cc = torch.cat(c_crossattn, 1)
-            out = self.diffusion_model(xc, t, context=cc)
+            out = self.diffusion_model(xc, t, context=cc, control=control)
         elif self.conditioning_key == 'hybrid-adm':
             assert c_adm is not None
             xc = torch.cat([x] + c_concat, dim=1)
             cc = torch.cat(c_crossattn, 1)
-            out = self.diffusion_model(xc, t, context=cc, y=c_adm)
+            out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control)
         elif self.conditioning_key == 'crossattn-adm':
             assert c_adm is not None
             cc = torch.cat(c_crossattn, 1)
-            out = self.diffusion_model(x, t, context=cc, y=c_adm)
+            out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control)
         elif self.conditioning_key == 'adm':
             cc = c_crossattn[0]
-            out = self.diffusion_model(x, t, y=cc)
+            out = self.diffusion_model(x, t, y=cc, control=control)
         else:
             raise NotImplementedError()
comfy/ldm/modules/diffusionmodules/openaimodel.py

@@ -753,7 +753,7 @@ class UNetModel(nn.Module):
         self.middle_block.apply(convert_module_to_f32)
         self.output_blocks.apply(convert_module_to_f32)

-    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+    def forward(self, x, timesteps=None, context=None, y=None, control=None, **kwargs):
         """
         Apply the model to an input batch.
         :param x: an [N x C x ...] Tensor of inputs.

@@ -778,8 +778,14 @@ class UNetModel(nn.Module):
             h = module(h, emb, context)
             hs.append(h)
         h = self.middle_block(h, emb, context)
+
+        if control is not None:
+            h += control.pop()
+
         for module in self.output_blocks:
-            h = th.cat([h, hs.pop()], dim=1)
+            hsp = hs.pop()
+            if control is not None:
+                hsp += control.pop()
+            h = th.cat([h, hsp], dim=1)
             h = module(h, emb, context)
         h = h.type(x.dtype)
         if self.predict_codebook_ids:
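
The injection order matters here: ControlNet.forward in cldm.py returns its residuals as a list ordered [input blocks..., middle block], and this hunk consumes that list with pop() from the end, so the middle residual is applied first and the input residuals then pair up with the skip connections in reverse. A plain-Python illustration of that protocol (an assumption drawn from reading both files together, not code from the commit):

controls = [f"input_{i}" for i in range(3)] + ["middle"]

print(controls.pop())       # 'middle': added to h right after middle_block
while controls:
    print(controls.pop())   # 'input_2', 'input_1', 'input_0': each added to
                            # the matching skip connection before the concat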
comfy/model_management.py

@@ -48,7 +48,7 @@ print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_s
 current_loaded_model = None
+current_gpu_controlnets = []

 model_accelerated = False

@@ -56,6 +56,7 @@ model_accelerated = False
 def unload_model():
     global current_loaded_model
     global model_accelerated
+    global current_gpu_controlnets
     if current_loaded_model is not None:
         if model_accelerated:
             accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)

@@ -64,6 +65,10 @@ def unload_model():
         current_loaded_model.model.cpu()
         current_loaded_model.unpatch_model()
         current_loaded_model = None
+    if len(current_gpu_controlnets) > 0:
+        for n in current_gpu_controlnets:
+            n.cpu()
+        current_gpu_controlnets = []

 def load_model_gpu(model):

@@ -95,6 +100,16 @@ def load_model_gpu(model):
     model_accelerated = True
     return current_loaded_model

+def load_controlnet_gpu(models):
+    global current_gpu_controlnets
+    for m in current_gpu_controlnets:
+        if m not in models:
+            m.cpu()
+
+    current_gpu_controlnets = []
+    for m in models:
+        current_gpu_controlnets.append(m.cuda())
+
+
 def get_free_memory():
     dev = torch.cuda.current_device()
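
The pattern behind load_controlnet_gpu is a small device cache: evict anything no longer requested, then load the new set. A standalone sketch of that bookkeeping with toy stand-ins for torch modules (not the real ComfyUI models):

class FakeModel:
    def __init__(self, name):
        self.name = name
    def cpu(self):
        print(f"{self.name} -> cpu")
        return self
    def cuda(self):
        print(f"{self.name} -> cuda")
        return self

current_gpu_controlnets = []

def load_controlnet_gpu(models):
    global current_gpu_controlnets
    for m in current_gpu_controlnets:
        if m not in models:
            m.cpu()              # evict only the models no longer wanted
    current_gpu_controlnets = [m.cuda() for m in models]

a, b = FakeModel("a"), FakeModel("b")
load_controlnet_gpu([a])         # a -> cuda
load_controlnet_gpu([b])         # a -> cpu, b -> cuda
load_controlnet_gpu([b])         # b -> cuda again (a no-op for a real module already on the GPU)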
comfy/samplers.py

@@ -21,12 +21,13 @@ class CFGDenoiser(torch.nn.Module):
         uncond = self.inner_model(x, sigma, cond=uncond)
         return uncond + (cond - uncond) * cond_scale

-def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_concat=None):
-    def get_area_and_mult(cond, x_in, cond_concat_in):
+#The main sampling function shared by all the samplers
+#Returns predicted noise
+def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None):
+    def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
         area = (x_in.shape[2], x_in.shape[3], 0, 0)
         strength = 1.0
+        min_sigma = 0.0
+        max_sigma = 999.0
         if 'area' in cond[1]:
             area = cond[1]['area']
         if 'strength' in cond[1]:

@@ -56,9 +57,15 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
                 cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
                 cropped.append(cr)
             conditionning['c_concat'] = torch.cat(cropped, dim=1)

-        return (input_x, mult, conditionning, area)
+        control = None
+        if 'control' in cond[1]:
+            control = cond[1]['control']
+        return (input_x, mult, conditionning, area, control)

     def cond_equal_size(c1, c2):
+        if c1 is c2:
+            return True
         if c1.keys() != c2.keys():
             return False
         if 'c_crossattn' in c1:

@@ -69,6 +76,17 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
             return False
         return True

+    def can_concat_cond(c1, c2):
+        if c1[0].shape != c2[0].shape:
+            return False
+        if (c1[4] is None) != (c2[4] is None):
+            return False
+        if c1[4] is not None:
+            if c1[4] is not c2[4]:
+                return False
+        return cond_equal_size(c1[2], c2[2])
+
     def cond_cat(c_list):
         c_crossattn = []
         c_concat = []

@@ -84,7 +102,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
             out['c_concat'] = [torch.cat(c_concat)]
         return out

-    def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area, cond_concat_in):
+    def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in):
         out_cond = torch.zeros_like(x_in)
         out_count = torch.ones_like(x_in)/100000.0

@@ -96,13 +114,13 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
         to_run = []
         for x in cond:
-            p = get_area_and_mult(x, x_in, cond_concat_in)
+            p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
             if p is None:
                 continue

             to_run += [(p, COND)]
         for x in uncond:
-            p = get_area_and_mult(x, x_in, cond_concat_in)
+            p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
             if p is None:
                 continue

@@ -113,8 +131,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
             first_shape = first[0][0].shape
             to_batch_temp = []
             for x in range(len(to_run)):
-                if to_run[x][0][0].shape == first_shape:
-                    if cond_equal_size(to_run[x][0][2], first[0][2]):
+                if can_concat_cond(to_run[x][0], first[0]):
                     to_batch_temp += [x]

             to_batch_temp.reverse()

@@ -131,6 +148,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
             c = []
             cond_or_uncond = []
             area = []
+            control = None
             for x in to_batch:
                 o = to_run.pop(x)
                 p = o[0]

@@ -139,13 +157,17 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
                 c += [p[2]]
                 area += [p[3]]
                 cond_or_uncond += [o[1]]
+                control = p[4]

             batch_chunks = len(cond_or_uncond)
             input_x = torch.cat(input_x)
             c = cond_cat(c)
-            sigma_ = torch.cat([sigma] * batch_chunks)
+            timestep_ = torch.cat([timestep] * batch_chunks)

-            output = model_function(input_x, sigma_, cond=c).chunk(batch_chunks)
+            if control is not None:
+                c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'])
+
+            output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
             del input_x

             for o in range(batch_chunks):

@@ -166,10 +188,29 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
     max_total_area = model_management.maximum_batch_area()
-    cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area, cond_concat)
+    cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat)
     return uncond + (cond - uncond) * cond_scale

-class CFGDenoiserComplex(torch.nn.Module):
+class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
+    def __init__(self, model, quantize=False, device='cpu'):
+        super().__init__(model, model.alphas_cumprod, quantize=quantize)
+
+    def get_v(self, x, t, cond, **kwargs):
+        return self.inner_model.apply_model(x, t, cond, **kwargs)
+
+
+class CFGNoisePredictor(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.inner_model = model
+        self.alphas_cumprod = model.alphas_cumprod
+    def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None):
+        out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat)
+        return out
+
+
+class KSamplerX0Inpaint(torch.nn.Module):
     def __init__(self, model):
         super().__init__()
         self.inner_model = model

@@ -177,7 +218,7 @@ class CFGDenoiserComplex(torch.nn.Module):
         if denoise_mask is not None:
             latent_mask = 1. - denoise_mask
             x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask
-        out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale, cond_concat)
+        out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat)
         if denoise_mask is not None:
             out *= denoise_mask

@@ -196,8 +237,6 @@ def simple_scheduler(model, steps):
 def blank_inpaint_image_like(latent_image):
     blank_image = torch.ones_like(latent_image)
     # these are the values for "zero" in pixel space translated to latent space
-    # the proper way to do this is to apply the mask to the image in pixel space and then send it through the VAE
-    # unfortunately that gives zero flexibility so I did things like this instead which hopefully works
     blank_image[:,0] *= 0.8223
     blank_image[:,1] *= -0.6876
     blank_image[:,2] *= 0.6364

@@ -234,6 +273,42 @@ def create_cond_with_same_area_if_none(conds, c):
     n = c[1].copy()
     conds += [[smallest[0], n]]

+def apply_control_net_to_equal_area(conds, uncond):
+    cond_cnets = []
+    cond_other = []
+    uncond_cnets = []
+    uncond_other = []
+    for t in range(len(conds)):
+        x = conds[t]
+        if 'area' not in x[1]:
+            if 'control' in x[1] and x[1]['control'] is not None:
+                cond_cnets.append(x[1]['control'])
+            else:
+                cond_other.append((x, t))
+    for t in range(len(uncond)):
+        x = uncond[t]
+        if 'area' not in x[1]:
+            if 'control' in x[1] and x[1]['control'] is not None:
+                uncond_cnets.append(x[1]['control'])
+            else:
+                uncond_other.append((x, t))
+
+    if len(uncond_cnets) > 0:
+        return
+
+    for x in range(len(cond_cnets)):
+        temp = uncond_other[x % len(uncond_other)]
+        o = temp[0]
+        if 'control' in o[1] and o[1]['control'] is not None:
+            n = o[1].copy()
+            n['control'] = cond_cnets[x]
+            uncond += [[o[0], n]]
+        else:
+            n = o[1].copy()
+            n['control'] = cond_cnets[x]
+            uncond[temp[1]] = [o[0], n]
+
 class KSampler:
     SCHEDULERS = ["karras", "normal", "simple"]
     SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral",

@@ -242,11 +317,13 @@ class KSampler:
     def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
         self.model = model
+        self.model_denoise = CFGNoisePredictor(self.model)
         if self.model.parameterization == "v":
-            self.model_wrap = k_diffusion_external.CompVisVDenoiser(self.model, quantize=True)
+            self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
         else:
-            self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model, quantize=True)
-        self.model_k = CFGDenoiserComplex(self.model_wrap)
+            self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
+        self.model_wrap.parameterization = self.model.parameterization
+        self.model_k = KSamplerX0Inpaint(self.model_wrap)
         self.device = device
         if scheduler not in self.SCHEDULERS:
             scheduler = self.SCHEDULERS[0]

@@ -316,6 +393,8 @@ class KSampler:
         for c in negative:
             create_cond_with_same_area_if_none(positive, c)

+        apply_control_net_to_equal_area(positive, negative)
+
         if self.model.model.diffusion_model.dtype == torch.float16:
             precision_scope = torch.autocast
         else:
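
One subtle point in the new can_concat_cond: conds are only batched together when they share the very same control object (identity, not equality), or when neither carries one. A toy restatement of that rule in plain Python (not the real cond tuples from the code above):

def can_share_batch(ctrl1, ctrl2):
    if (ctrl1 is None) != (ctrl2 is None):
        return False                      # one controlled, one not: split
    return ctrl1 is None or ctrl1 is ctrl2

cnet = object()
print(can_share_batch(None, None))        # True: neither uses a ControlNet
print(can_share_batch(cnet, cnet))        # True: the exact same instance
print(can_share_batch(cnet, None))        # False
print(can_share_batch(cnet, object()))    # False: different ControlNets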
comfy/sd.py

@@ -6,6 +6,9 @@ import model_management
 from ldm.util import instantiate_from_config
 from ldm.models.autoencoder import AutoencoderKL
 from omegaconf import OmegaConf
+from .cldm import cldm
+from . import utils
+

 def load_torch_file(ckpt):
     if ckpt.lower().endswith(".safetensors"):

@@ -323,6 +326,79 @@ class VAE:
         samples = samples.cpu()
         return samples

+class ControlNet:
+    def __init__(self, control_model):
+        self.control_model = control_model
+        self.cond_hint_original = None
+        self.cond_hint = None
+
+    def get_control(self, x_noisy, t, cond_txt):
+        if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
+            if self.cond_hint is not None:
+                del self.cond_hint
+                self.cond_hint = None
+            self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(x_noisy.device)
+            print("set cond_hint", self.cond_hint.shape)
+        control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
+        return control
+
+    def set_cond_hint(self, cond_hint):
+        self.cond_hint_original = cond_hint
+        return self
+
+    def cleanup(self):
+        if self.cond_hint is not None:
+            del self.cond_hint
+            self.cond_hint = None
+
+    def copy(self):
+        c = ControlNet(self.control_model)
+        c.cond_hint_original = self.cond_hint_original
+        return c
+
+def load_controlnet(ckpt_path):
+    controlnet_data = load_torch_file(ckpt_path)
+    pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
+    pth = False
+    sd2 = False
+    key = 'input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
+    if pth_key in controlnet_data:
+        pth = True
+        key = pth_key
+    elif key in controlnet_data:
+        pass
+    else:
+        print("error checkpoint does not contain controlnet data", ckpt_path)
+        return None
+
+    context_dim = controlnet_data[key].shape[1]
+    control_model = cldm.ControlNet(image_size=32,
+                                    in_channels=4,
+                                    hint_channels=3,
+                                    model_channels=320,
+                                    attention_resolutions=[4, 2, 1],
+                                    num_res_blocks=2,
+                                    channel_mult=[1, 2, 4, 4],
+                                    num_heads=8,
+                                    use_spatial_transformer=True,
+                                    transformer_depth=1,
+                                    context_dim=context_dim,
+                                    use_checkpoint=True,
+                                    legacy=False)
+
+    if pth:
+        class WeightsLoader(torch.nn.Module):
+            pass
+        w = WeightsLoader()
+        w.control_model = control_model
+        w.load_state_dict(controlnet_data, strict=False)
+    else:
+        control_model.load_state_dict(controlnet_data, strict=False)
+
+    control = ControlNet(control_model)
+    return control
+
 def load_clip(ckpt_path, embedding_directory=None):
     clip_data = load_torch_file(ckpt_path)
     config = {}
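
How the new wrapper is meant to be used, as far as this diff shows: one heavyweight control_model is shared, and each conditioning entry holds a cheap copy() carrying its own hint. A hedged usage sketch (the checkpoint path and hint tensor below are placeholders; this assumes the repo layout at this commit):

import torch
import comfy.sd

cnet = comfy.sd.load_controlnet("models/controlnet/some_controlnet.pth")  # placeholder path
hint = torch.rand(1, 3, 512, 512)             # NCHW hint image in 0..1 (placeholder)
per_cond = cnet.copy().set_cond_hint(hint)    # set_cond_hint returns self for chaining
# During sampling, comfy.samplers calls per_cond.get_control(x_noisy, t, cond_txt),
# which rescales the hint to 8x the latent size and runs the control model.
per_cond.cleanup()                            # drop the rescaled hint afterwards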
comfy/utils.py (new file, 0 → 100644)

import torch

def common_upscale(samples, width, height, upscale_method, crop):
    if crop == "center":
        old_width = samples.shape[3]
        old_height = samples.shape[2]
        old_aspect = old_width / old_height
        new_aspect = width / height
        x = 0
        y = 0
        if old_aspect > new_aspect:
            x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
        elif old_aspect < new_aspect:
            y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
        s = samples[:,:,y:old_height-y,x:old_width-x]
    else:
        s = samples
    return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
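
A quick usage example of the new helper (imported as nodes.py does): with crop="center" the input is first cropped to the target aspect ratio, so upscaling a wide batch to a square trims the left and right edges before interpolating.

import torch
import comfy.utils  # assumes running from the repo root, as nodes.py does

samples = torch.rand(1, 3, 64, 128)   # NCHW batch, wider than tall
out = comfy.utils.common_upscale(samples, 96, 96, "nearest-exact", "center")
print(out.shape)                      # torch.Size([1, 3, 96, 96])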
nodes.py

@@ -15,10 +15,12 @@ sys.path.insert(0, os.path.join(sys.path[0], "comfy"))
 import comfy.samplers
 import comfy.sd
+import comfy.utils

 import model_management

-supported_ckpt_extensions = ['.ckpt']
-supported_pt_extensions = ['.ckpt', '.pt', '.bin']
+supported_ckpt_extensions = ['.ckpt', '.pth']
+supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth']
 try:
     import safetensors.torch
     supported_ckpt_extensions += ['.safetensors']

@@ -77,12 +79,14 @@ class ConditioningSetArea:
     CATEGORY = "conditioning"

     def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0):
-        c = copy.deepcopy(conditioning)
-        for t in c:
-            t[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
-            t[1]['strength'] = strength
-            t[1]['min_sigma'] = min_sigma
-            t[1]['max_sigma'] = max_sigma
+        c = []
+        for t in conditioning:
+            n = [t[0], t[1].copy()]
+            n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
+            n[1]['strength'] = strength
+            n[1]['min_sigma'] = min_sigma
+            n[1]['max_sigma'] = max_sigma
+            c.append(n)
         return (c, )

 class VAEDecode:

@@ -134,7 +138,6 @@ class VAEEncodeForInpaint:
     CATEGORY = "latent/inpaint"

     def encode(self, vae, pixels, mask):
-        print(pixels.shape, mask.shape)
         x = (pixels.shape[1] // 64) * 64
         y = (pixels.shape[2] // 64) * 64
         if pixels.shape[1] != x or pixels.shape[2] != y:

@@ -144,7 +147,6 @@ class VAEEncodeForInpaint:
         #shave off a few pixels to keep things seamless
         kernel_tensor = torch.ones((1, 1, 6, 6))
         mask_erosion = torch.clamp(torch.nn.functional.conv2d((1.0 - mask.round())[None], kernel_tensor, padding=3), 0, 1)
-        print(mask_erosion.shape, pixels.shape)
         for i in range(3):
             pixels[:,:,:,i] -= 0.5
             pixels[:,:,:,i] *= mask_erosion[0][:x,:y].round()

@@ -211,6 +213,44 @@ class VAELoader:
         vae = comfy.sd.VAE(ckpt_path=vae_path)
         return (vae,)

+class ControlNetLoader:
+    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
+    controlnet_dir = os.path.join(models_dir, "controlnet")
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}}
+
+    RETURN_TYPES = ("CONTROL_NET",)
+    FUNCTION = "load_controlnet"
+
+    CATEGORY = "loaders"
+
+    def load_controlnet(self, control_net_name):
+        controlnet_path = os.path.join(self.controlnet_dir, control_net_name)
+        controlnet = comfy.sd.load_controlnet(controlnet_path)
+        return (controlnet,)
+
+
+class ControlNetApply:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"conditioning": ("CONDITIONING", ), "control_net": ("CONTROL_NET", ), "image": ("IMAGE", )}}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "apply_controlnet"
+
+    CATEGORY = "conditioning"
+
+    def apply_controlnet(self, conditioning, control_net, image):
+        c = []
+        control_hint = image.movedim(-1, 1)
+        print(control_hint.shape)
+        for t in conditioning:
+            n = [t[0], t[1].copy()]
+            n[1]['control'] = control_net.copy().set_cond_hint(control_hint)
+            c.append(n)
+        return (c, )
+
 class CLIPLoader:
     models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
     clip_dir = os.path.join(models_dir, "clip")

@@ -248,22 +288,7 @@ class EmptyLatentImage:
         latent = torch.zeros([batch_size, 4, height // 8, width // 8])
         return ({"samples": latent}, )

-def common_upscale(samples, width, height, upscale_method, crop):
-    if crop == "center":
-        old_width = samples.shape[3]
-        old_height = samples.shape[2]
-        old_aspect = old_width / old_height
-        new_aspect = width / height
-        x = 0
-        y = 0
-        if old_aspect > new_aspect:
-            x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
-        elif old_aspect < new_aspect:
-            y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
-        s = samples[:,:,y:old_height-y,x:old_width-x]
-    else:
-        s = samples
-    return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)

 class LatentUpscale:
     upscale_methods = ["nearest-exact", "bilinear", "area"]

@@ -282,7 +307,7 @@ class LatentUpscale:
     def upscale(self, samples, upscale_method, width, height, crop):
         s = samples.copy()
-        s["samples"] = common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
+        s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
         return (s,)

 class LatentRotate:

@@ -461,19 +486,26 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
     positive_copy = []
     negative_copy = []

+    control_nets = []
     for p in positive:
         t = p[0]
         if t.shape[0] < noise.shape[0]:
             t = torch.cat([t] * noise.shape[0])
         t = t.to(device)
+        if 'control' in p[1]:
+            control_nets += [p[1]['control']]
         positive_copy += [[t] + p[1:]]
     for n in negative:
         t = n[0]
         if t.shape[0] < noise.shape[0]:
             t = torch.cat([t] * noise.shape[0])
         t = t.to(device)
+        if 'control' in p[1]:
+            control_nets += [p[1]['control']]
         negative_copy += [[t] + n[1:]]

+    model_management.load_controlnet_gpu(list(map(lambda a: a.control_model, control_nets)))
+
     if sampler_name in comfy.samplers.KSampler.SAMPLERS:
         sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
     else:

@@ -482,6 +514,9 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
     samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
     samples = samples.cpu()
+    for c in control_nets:
+        c.cleanup()
+
     out = latent.copy()
     out["samples"] = samples
     return (out, )

@@ -676,7 +711,7 @@ class ImageScale:
     def upscale(self, image, upscale_method, width, height, crop):
         samples = image.movedim(-1, 1)
-        s = common_upscale(samples, width, height, upscale_method, crop)
+        s = comfy.utils.common_upscale(samples, width, height, upscale_method, crop)
         s = s.movedim(1, -1)
         return (s,)

@@ -704,6 +739,8 @@ NODE_CLASS_MAPPINGS = {
     "LatentCrop": LatentCrop,
     "LoraLoader": LoraLoader,
     "CLIPLoader": CLIPLoader,
+    "ControlNetApply": ControlNetApply,
+    "ControlNetLoader": ControlNetLoader,
 }
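
Putting the two new nodes together as a graph would (a hedged sketch; the checkpoint name and the POSITIVE/IMAGE inputs are placeholders standing in for other nodes' outputs):

loader = ControlNetLoader()
(control_net,) = loader.load_controlnet("some_controlnet.pth")   # placeholder file name

apply_node = ControlNetApply()
(conditioning,) = apply_node.apply_controlnet(POSITIVE, control_net, IMAGE)
# Each conditioning entry now carries a 'control' key; common_ksampler collects
# those, pushes the control models to the GPU via model_management.load_controlnet_gpu,
# and calls cleanup() on them once sampling finishes.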