chenpangpang / ComfyUI · Commits

Unverified commit fa66ece2, authored Feb 17, 2023 by Fannovel16, committed by GitHub on Feb 17, 2023

    Merge branch 'comfyanonymous:master' into master

Parents: 1c5fe809, 1cf5b612
Showing 12 changed files with 983 additions and 96 deletions (+983 −96)
comfy/cldm/cldm.py                                  +286  −0
comfy/extra_samplers/uni_pc.py                      +20   −7
comfy/ldm/models/diffusion/ddpm.py                  +9    −9
comfy/ldm/modules/diffusionmodules/openaimodel.py   +8    −2
comfy/model_management.py                           +16   −1
comfy/samplers.py                                   +186  −30
comfy/sd.py                                         +81   −0
comfy/utils.py                                      +18   −0
main.py                                             +4    −0
models/configs/v2-inpainting-inference.yaml         +158  −0
models/controlnet/put_controlnets_here              +0    −0
nodes.py                                            +197  −47
comfy/cldm/cldm.py (new file, 0 → 100644)
#taken from: https://github.com/lllyasviel/ControlNet
#and modified

import einops
import torch
import torch as th
import torch.nn as nn

from ldm.modules.diffusionmodules.util import (
    conv_nd,
    linear,
    zero_module,
    timestep_embedding,
)

from einops import rearrange, repeat
from torchvision.utils import make_grid
from ldm.modules.attention import SpatialTransformer
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.util import log_txt_as_img, exists, instantiate_from_config


class ControlledUnetModel(UNetModel):
    #implemented in the ldm unet
    pass

class ControlNet(nn.Module):
    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        hint_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.dims = dims
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])

        self.input_hint_block = TimestepEmbedSequential(
            conv_nd(dims, hint_channels, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 32, 32, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 96, 96, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
            zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
        )

        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            ) if not use_spatial_transformer else SpatialTransformer(
                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_convs.append(self.make_zero_conv(ch))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                self.zero_convs.append(self.make_zero_conv(ch))
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            #num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self.middle_block_out = self.make_zero_conv(ch)
        self._feature_size += ch

    def make_zero_conv(self, channels):
        return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))

    def forward(self, x, hint, timesteps, context, **kwargs):
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        guided_hint = self.input_hint_block(hint, emb, context)

        outs = []

        h = x.type(self.dtype)
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            if guided_hint is not None:
                h = module(h, emb, context)
                h += guided_hint
                guided_hint = None
            else:
                h = module(h, emb, context)
            outs.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)
        outs.append(self.middle_block_out(h, emb, context))

        return outs
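
The zero convolutions built by make_zero_conv() above are the load-bearing detail: zero_module() zeroes every parameter, so each residual this branch emits is exactly zero at initialization and the ControlNet starts as a no-op on top of the frozen UNet. A minimal standalone sketch of that property (not part of the commit; zero_module is re-declared here to mirror ldm.modules.diffusionmodules.util.zero_module):

import torch
import torch.nn as nn

def zero_module(module):
    # zero out all parameters, as the ldm helper does
    for p in module.parameters():
        nn.init.zeros_(p)
    return module

zero_conv = zero_module(nn.Conv2d(320, 320, 1, padding=0))
h = torch.randn(1, 320, 64, 64)
assert torch.all(zero_conv(h) == 0)  # residual added to the UNet is zero at init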
comfy/extra_samplers/uni_pc.py

@@ -358,7 +358,10 @@ class UniPC:
         predict_x0=True,
         thresholding=False,
         max_val=1.,
-        variant='bh1'
+        variant='bh1',
+        noise_mask=None,
+        masked_image=None,
+        noise=None,
     ):
         """Construct a UniPC.

@@ -370,6 +373,9 @@ class UniPC:
         self.predict_x0 = predict_x0
         self.thresholding = thresholding
         self.max_val = max_val
+        self.noise_mask = noise_mask
+        self.masked_image = masked_image
+        self.noise = noise

     def dynamic_thresholding_fn(self, x0, t=None):
         """

@@ -386,6 +392,9 @@ class UniPC:
         """
         Return the noise prediction model.
         """
-        return self.model(x, t)
+        if self.noise_mask is not None:
+            return self.model(x, t) * self.noise_mask
+        else:
+            return self.model(x, t)

     def data_prediction_fn(self, x, t):

@@ -401,6 +410,8 @@ class UniPC:
             s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
             s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
             x0 = torch.clamp(x0, -s, s) / s
+        if self.noise_mask is not None:
+            x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image
         return x0

     def model_fn(self, x, t):

@@ -713,6 +724,8 @@ class UniPC:
             assert timesteps.shape[0] - 1 == steps
             # with torch.no_grad():
             for step_index in trange(steps):
+                if self.noise_mask is not None:
+                    x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index]))
                 if step_index == 0:
                     vec_t = timesteps[0].expand((x.shape[0]))
                     model_prev_list = [self.model_fn(x, vec_t)]

@@ -820,7 +833,7 @@ def expand_dims(v, dims):
-def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None):
+def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None, noise_mask=None):
     to_zero = False
     if sigmas[-1] == 0:
         timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]

@@ -843,13 +856,13 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None
     device = noise.device
-    if model.inner_model.parameterization == "v":
+    if model.parameterization == "v":
         model_type = "v"
     else:
         model_type = "noise"

     model_fn = model_wrapper(
-        model.inner_model.apply_model,
+        model.inner_model.inner_model.apply_model,
         sampling_function,
         ns,
         model_type=model_type,

@@ -857,7 +870,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None
         model_kwargs=extra_args,
     )

-    uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False)
+    uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise)
     x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True)
     if not to_zero:
         x /= ns.marginal_alpha(timesteps[-1])
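
The noise_mask changes above keep the region outside the mask pinned to the known image: at every step the sampler's x is overwritten there with the masked image re-noised to the current timestep, so both regions share one noise level. A standalone sketch of that update (the function name is made up; alpha_t and std_t stand in for noise_schedule.marginal_alpha/marginal_std values):

import torch

def masked_step_input(x, noise_mask, masked_image, noise, alpha_t, std_t):
    # inside the mask (noise_mask == 1): keep the sampler's current x;
    # outside it: re-noise the known image to the current noise level
    return x * noise_mask + (1. - noise_mask) * (masked_image * alpha_t + noise * std_t)

x = torch.randn(1, 4, 8, 8)
mask = torch.zeros(1, 4, 8, 8)
mask[..., :4] = 1.0                        # denoise only the left half
image = torch.randn(1, 4, 8, 8)            # latents of the known image
noise = torch.randn(1, 4, 8, 8)            # the fixed noise handed to UniPC
out = masked_step_input(x, mask, image, noise, alpha_t=0.9, std_t=0.43)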
comfy/ldm/models/diffusion/ddpm.py

@@ -1320,12 +1320,12 @@ class DiffusionWrapper(torch.nn.Module):
         self.conditioning_key = conditioning_key
         assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']

-    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
+    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None):
         if self.conditioning_key is None:
-            out = self.diffusion_model(x, t)
+            out = self.diffusion_model(x, t, control=control)
         elif self.conditioning_key == 'concat':
             xc = torch.cat([x] + c_concat, dim=1)
-            out = self.diffusion_model(xc, t)
+            out = self.diffusion_model(xc, t, control=control)
         elif self.conditioning_key == 'crossattn':
             if not self.sequential_cross_attn:
                 cc = torch.cat(c_crossattn, 1)

@@ -1335,25 +1335,25 @@ class DiffusionWrapper(torch.nn.Module):
                 # TorchScript changes names of the arguments
                 # with argument cc defined as context=cc scripted model will produce
                 # an error: RuntimeError: forward() is missing value for argument 'argument_3'.
-                out = self.scripted_diffusion_model(x, t, cc)
+                out = self.scripted_diffusion_model(x, t, cc, control=control)
             else:
-                out = self.diffusion_model(x, t, context=cc)
+                out = self.diffusion_model(x, t, context=cc, control=control)
         elif self.conditioning_key == 'hybrid':
             xc = torch.cat([x] + c_concat, dim=1)
             cc = torch.cat(c_crossattn, 1)
-            out = self.diffusion_model(xc, t, context=cc)
+            out = self.diffusion_model(xc, t, context=cc, control=control)
         elif self.conditioning_key == 'hybrid-adm':
             assert c_adm is not None
             xc = torch.cat([x] + c_concat, dim=1)
             cc = torch.cat(c_crossattn, 1)
-            out = self.diffusion_model(xc, t, context=cc, y=c_adm)
+            out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control)
         elif self.conditioning_key == 'crossattn-adm':
             assert c_adm is not None
             cc = torch.cat(c_crossattn, 1)
-            out = self.diffusion_model(x, t, context=cc, y=c_adm)
+            out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control)
         elif self.conditioning_key == 'adm':
             cc = c_crossattn[0]
-            out = self.diffusion_model(x, t, y=cc)
+            out = self.diffusion_model(x, t, y=cc, control=control)
         else:
             raise NotImplementedError()
comfy/ldm/modules/diffusionmodules/openaimodel.py

@@ -753,7 +753,7 @@ class UNetModel(nn.Module):
         self.middle_block.apply(convert_module_to_f32)
         self.output_blocks.apply(convert_module_to_f32)

-    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+    def forward(self, x, timesteps=None, context=None, y=None, control=None, **kwargs):
         """
         Apply the model to an input batch.
         :param x: an [N x C x ...] Tensor of inputs.

@@ -778,8 +778,14 @@ class UNetModel(nn.Module):
             h = module(h, emb, context)
             hs.append(h)
         h = self.middle_block(h, emb, context)
+        if control is not None:
+            h += control.pop()
+
         for module in self.output_blocks:
-            h = th.cat([h, hs.pop()], dim=1)
+            hsp = hs.pop()
+            if control is not None:
+                hsp += control.pop()
+            h = th.cat([h, hsp], dim=1)
             h = module(h, emb, context)
         h = h.type(x.dtype)
         if self.predict_codebook_ids:
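
An orientation note (an observation, not code from the commit): ControlNet.forward in comfy/cldm/cldm.py above returns its residuals as a list in input-block order with the middle-block output appended last, so the control.pop() calls added here consume them middle-first and then deepest skip outward, matching the order in which output_blocks drain hs:

# toy stand-in for the residual list built by ControlNet.forward
control = [f"zero_conv_{i}" for i in range(12)] + ["middle"]
assert control.pop() == "middle"        # consumed right after self.middle_block
assert control.pop() == "zero_conv_11"  # added to the deepest skip tensor hsp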
comfy/model_management.py

@@ -48,7 +48,7 @@ print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_s
 current_loaded_model = None
+current_gpu_controlnets = []

 model_accelerated = False

@@ -56,6 +56,7 @@ model_accelerated = False
 def unload_model():
     global current_loaded_model
     global model_accelerated
+    global current_gpu_controlnets
     if current_loaded_model is not None:
         if model_accelerated:
             accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)

@@ -64,6 +65,10 @@ def unload_model():
         current_loaded_model.model.cpu()
         current_loaded_model.unpatch_model()
         current_loaded_model = None
+    if len(current_gpu_controlnets) > 0:
+        for n in current_gpu_controlnets:
+            n.cpu()
+        current_gpu_controlnets = []

 def load_model_gpu(model):

@@ -95,6 +100,16 @@ def load_model_gpu(model):
         model_accelerated = True
     return current_loaded_model

+def load_controlnet_gpu(models):
+    global current_gpu_controlnets
+    for m in current_gpu_controlnets:
+        if m not in models:
+            m.cpu()
+
+    current_gpu_controlnets = []
+    for m in models:
+        current_gpu_controlnets.append(m.cuda())
+
 def get_free_memory():
     dev = torch.cuda.current_device()
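
A hedged usage sketch of the cache added above (assumes a CUDA device; the two Conv2d modules are hypothetical stand-ins for loaded ControlNet control models):

import torch.nn as nn
import model_management

cn_a, cn_b = nn.Conv2d(3, 3, 1), nn.Conv2d(3, 3, 1)
model_management.load_controlnet_gpu([cn_a])  # cn_a -> GPU
model_management.load_controlnet_gpu([cn_b])  # cn_a -> CPU first, then cn_b -> GPU
model_management.unload_model()               # returns cn_b to the CPU as well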
comfy/samplers.py

@@ -21,12 +21,13 @@ class CFGDenoiser(torch.nn.Module):
         uncond = self.inner_model(x, sigma, cond=uncond)
         return uncond + (cond - uncond) * cond_scale

-def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
-    def get_area_and_mult(cond, x_in):
+#The main sampling function shared by all the samplers
+#Returns predicted noise
+def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None):
+    def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
         area = (x_in.shape[2], x_in.shape[3], 0, 0)
         strength = 1.0
+        min_sigma = 0.0
+        max_sigma = 999.0
         if 'area' in cond[1]:
             area = cond[1]['area']
         if 'strength' in cond[1]:

@@ -48,9 +49,60 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
         if (area[1] + area[3]) < x_in.shape[3]:
             for t in range(rr):
                 mult[:,:,:,area[1] + area[3] - 1 - t:area[1] + area[3] - t] *= ((1.0/rr) * (t + 1))
-        return (input_x, mult, cond[0], area)
+
+        conditionning = {}
+        conditionning['c_crossattn'] = cond[0]
+        if cond_concat_in is not None and len(cond_concat_in) > 0:
+            cropped = []
+            for x in cond_concat_in:
+                cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
+                cropped.append(cr)
+            conditionning['c_concat'] = torch.cat(cropped, dim=1)
+
+        control = None
+        if 'control' in cond[1]:
+            control = cond[1]['control']
+        return (input_x, mult, conditionning, area, control)
+
+    def cond_equal_size(c1, c2):
+        if c1 is c2:
+            return True
+        if c1.keys() != c2.keys():
+            return False
+        if 'c_crossattn' in c1:
+            if c1['c_crossattn'].shape != c2['c_crossattn'].shape:
+                return False
+        if 'c_concat' in c1:
+            if c1['c_concat'].shape != c2['c_concat'].shape:
+                return False
+        return True
+
+    def can_concat_cond(c1, c2):
+        if c1[0].shape != c2[0].shape:
+            return False
+        if (c1[4] is None) != (c2[4] is None):
+            return False
+        if c1[4] is not None:
+            if c1[4] is not c2[4]:
+                return False
+        return cond_equal_size(c1[2], c2[2])
+
+    def cond_cat(c_list):
+        c_crossattn = []
+        c_concat = []
+        for x in c_list:
+            if 'c_crossattn' in x:
+                c_crossattn.append(x['c_crossattn'])
+            if 'c_concat' in x:
+                c_concat.append(x['c_concat'])
+        out = {}
+        if len(c_crossattn) > 0:
+            out['c_crossattn'] = [torch.cat(c_crossattn)]
+        if len(c_concat) > 0:
+            out['c_concat'] = [torch.cat(c_concat)]
+        return out

-    def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area):
+    def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in):
         out_cond = torch.zeros_like(x_in)
         out_count = torch.ones_like(x_in)/100000.0

@@ -62,13 +114,13 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
         to_run = []
         for x in cond:
-            p = get_area_and_mult(x, x_in)
+            p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
             if p is None:
                 continue

             to_run += [(p, COND)]
         for x in uncond:
-            p = get_area_and_mult(x, x_in)
+            p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
             if p is None:
                 continue

@@ -79,8 +131,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
             first_shape = first[0][0].shape
             to_batch_temp = []
             for x in range(len(to_run)):
-                if to_run[x][0][0].shape == first_shape:
-                    if to_run[x][0][2].shape == first[0][2].shape:
-                        to_batch_temp += [x]
+                if can_concat_cond(to_run[x][0], first[0]):
+                    to_batch_temp += [x]

             to_batch_temp.reverse()

@@ -97,6 +148,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
             c = []
             cond_or_uncond = []
             area = []
+            control = None
             for x in to_batch:
                 o = to_run.pop(x)
                 p = o[0]

@@ -105,13 +157,17 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
                 c += [p[2]]
                 area += [p[3]]
                 cond_or_uncond += [o[1]]
+                control = p[4]

             batch_chunks = len(cond_or_uncond)
             input_x = torch.cat(input_x)
-            c = torch.cat(c)
-            sigma_ = torch.cat([sigma] * batch_chunks)
+            c = cond_cat(c)
+            timestep_ = torch.cat([timestep] * batch_chunks)

-            output = model_function(input_x, sigma_, cond=c).chunk(batch_chunks)
+            if control is not None:
+                c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'])
+
+            output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
             del input_x

             for o in range(batch_chunks):

@@ -132,15 +188,43 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale):
     max_total_area = model_management.maximum_batch_area()
-    cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area)
+    cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat)
     return uncond + (cond - uncond) * cond_scale

-class CFGDenoiserComplex(torch.nn.Module):
+class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
+    def __init__(self, model, quantize=False, device='cpu'):
+        super().__init__(model, model.alphas_cumprod, quantize=quantize)
+
+    def get_v(self, x, t, cond, **kwargs):
+        return self.inner_model.apply_model(x, t, cond, **kwargs)
+
+
+class CFGNoisePredictor(torch.nn.Module):
     def __init__(self, model):
         super().__init__()
         self.inner_model = model
-    def forward(self, x, sigma, uncond, cond, cond_scale):
-        return sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale)
+        self.alphas_cumprod = model.alphas_cumprod
+    def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None):
+        out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat)
+        return out
+
+
+class KSamplerX0Inpaint(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.inner_model = model
+    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None):
+        if denoise_mask is not None:
+            latent_mask = 1. - denoise_mask
+            x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask
+        out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat)
+        if denoise_mask is not None:
+            out *= denoise_mask
+
+        if denoise_mask is not None:
+            out += self.latent_image * latent_mask
+        return out

 def simple_scheduler(model, steps):
     sigs = []
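
KSamplerX0Inpaint above handles the k-diffusion side of inpainting with two masked composites: the input is re-noised outside the mask before each model call, and the output is recombined with the original latents afterwards. A self-contained sketch of the output composite (the helper name is made up):

import torch

def x0_inpaint_composite(denoised, latent_image, denoise_mask):
    # keep the model's prediction where denoise_mask == 1,
    # the original latents everywhere else
    latent_mask = 1. - denoise_mask
    return denoised * denoise_mask + latent_image * latent_mask

d = torch.randn(1, 4, 8, 8)
li = torch.randn(1, 4, 8, 8)
m = (torch.rand(1, 4, 8, 8) > 0.5).float()
out = x0_inpaint_composite(d, li, m)
assert torch.equal(out * (1 - m), li * (1 - m))  # masked-off region untouched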
@@ -150,6 +234,15 @@ def simple_scheduler(model, steps):
     sigs += [0.0]
     return torch.FloatTensor(sigs)

+def blank_inpaint_image_like(latent_image):
+    blank_image = torch.ones_like(latent_image)
+    # these are the values for "zero" in pixel space translated to latent space
+    blank_image[:,0] *= 0.8223
+    blank_image[:,1] *= -0.6876
+    blank_image[:,2] *= 0.6364
+    blank_image[:,3] *= 0.1380
+    return blank_image
+
 def create_cond_with_same_area_if_none(conds, c):
     if 'area' not in c[1]:
         return

@@ -180,6 +273,42 @@ def create_cond_with_same_area_if_none(conds, c):
     n = c[1].copy()
     conds += [[smallest[0], n]]

+def apply_control_net_to_equal_area(conds, uncond):
+    cond_cnets = []
+    cond_other = []
+    uncond_cnets = []
+    uncond_other = []
+    for t in range(len(conds)):
+        x = conds[t]
+        if 'area' not in x[1]:
+            if 'control' in x[1] and x[1]['control'] is not None:
+                cond_cnets.append(x[1]['control'])
+            else:
+                cond_other.append((x, t))
+    for t in range(len(uncond)):
+        x = uncond[t]
+        if 'area' not in x[1]:
+            if 'control' in x[1] and x[1]['control'] is not None:
+                uncond_cnets.append(x[1]['control'])
+            else:
+                uncond_other.append((x, t))
+
+    if len(uncond_cnets) > 0:
+        return
+
+    for x in range(len(cond_cnets)):
+        temp = uncond_other[x % len(uncond_other)]
+        o = temp[0]
+        if 'control' in o[1] and o[1]['control'] is not None:
+            n = o[1].copy()
+            n['control'] = cond_cnets[x]
+            uncond += [[o[0], n]]
+        else:
+            n = o[1].copy()
+            n['control'] = cond_cnets[x]
+            uncond[temp[1]] = [o[0], n]
+
 class KSampler:
     SCHEDULERS = ["karras", "normal", "simple"]
     SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral",

@@ -188,11 +317,13 @@ class KSampler:
     def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
         self.model = model
+        self.model_denoise = CFGNoisePredictor(self.model)
         if self.model.parameterization == "v":
-            self.model_wrap = k_diffusion_external.CompVisVDenoiser(self.model, quantize=True)
+            self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
         else:
-            self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model, quantize=True)
-        self.model_k = CFGDenoiserComplex(self.model_wrap)
+            self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
+        self.model_wrap.parameterization = self.model.parameterization
+        self.model_k = KSamplerX0Inpaint(self.model_wrap)
         self.device = device
         if scheduler not in self.SCHEDULERS:
             scheduler = self.SCHEDULERS[0]

@@ -200,8 +331,8 @@ class KSampler:
             sampler = self.SAMPLERS[0]
         self.scheduler = scheduler
         self.sampler = sampler
-        self.sigma_min = float(self.model_wrap.sigmas[0])
-        self.sigma_max = float(self.model_wrap.sigmas[-1])
+        self.sigma_min = float(self.model_wrap.sigma_min)
+        self.sigma_max = float(self.model_wrap.sigma_max)
         self.set_steps(steps, denoise)

     def _calculate_sigmas(self, steps):

@@ -235,7 +366,7 @@ class KSampler:
         self.sigmas = sigmas[-(steps + 1):]

-    def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False):
+    def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None):
         sigmas = self.sigmas
         sigma_min = self.sigma_min

@@ -262,22 +393,47 @@ class KSampler:
         for c in negative:
             create_cond_with_same_area_if_none(positive, c)

+        apply_control_net_to_equal_area(positive, negative)
+
         if self.model.model.diffusion_model.dtype == torch.float16:
             precision_scope = torch.autocast
         else:
             precision_scope = contextlib.nullcontext

+        extra_args = {"cond": positive, "uncond": negative, "cond_scale": cfg}
+
+        if hasattr(self.model, 'concat_keys'):
+            cond_concat = []
+            for ck in self.model.concat_keys:
+                if denoise_mask is not None:
+                    if ck == "mask":
+                        cond_concat.append(denoise_mask[:,:1])
+                    elif ck == "masked_image":
+                        cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
+                else:
+                    if ck == "mask":
+                        cond_concat.append(torch.ones_like(noise)[:,:1])
+                    elif ck == "masked_image":
+                        cond_concat.append(blank_inpaint_image_like(noise))
+            extra_args["cond_concat"] = cond_concat
+
         with precision_scope(self.device):
             if self.sampler == "uni_pc":
-                samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
+                samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask)
             else:
-                noise *= sigmas[0]
+                extra_args["denoise_mask"] = denoise_mask
+                self.model_k.latent_image = latent_image
+                self.model_k.noise = noise
+
+                noise = noise * sigmas[0]
+
                 if latent_image is not None:
                     noise += latent_image
                 if self.sampler == "sample_dpm_fast":
-                    samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
+                    samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args)
                 elif self.sampler == "sample_dpm_adaptive":
-                    samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
+                    samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args)
                 else:
-                    samples = getattr(k_diffusion_sampling, self.sampler)(self.model_k, noise, sigmas, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg})
+                    samples = getattr(k_diffusion_sampling, self.sampler)(self.model_k, noise, sigmas, extra_args=extra_args)
         return samples.to(torch.float32)
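
When the wrapped model exposes concat_keys (true for the SD2 inpainting model configured in models/configs/v2-inpainting-inference.yaml below), the cond_concat list built above widens the UNet input from 4 to 9 channels. A sketch of the resulting channel layout:

import torch

latent = torch.randn(1, 4, 64, 64)          # "samples" being denoised
mask = torch.ones(1, 1, 64, 64)             # concat key "mask": denoise everything
masked_image = torch.randn(1, 4, 64, 64)    # concat key "masked_image" latents
xc = torch.cat([latent, mask, masked_image], dim=1)
assert xc.shape[1] == 9                     # matches in_channels: 9 in the yaml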
comfy/sd.py

@@ -6,6 +6,9 @@ import model_management
 from ldm.util import instantiate_from_config
 from ldm.models.autoencoder import AutoencoderKL
 from omegaconf import OmegaConf
+from .cldm import cldm
+from . import utils
+

 def load_torch_file(ckpt):
     if ckpt.lower().endswith(".safetensors"):

@@ -323,6 +326,84 @@ class VAE:
         samples = samples.cpu()
         return samples

+class ControlNet:
+    def __init__(self, control_model):
+        self.control_model = control_model
+        self.cond_hint_original = None
+        self.cond_hint = None
+        self.strength = 1.0
+
+    def get_control(self, x_noisy, t, cond_txt):
+        if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
+            if self.cond_hint is not None:
+                del self.cond_hint
+                self.cond_hint = None
+            self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(x_noisy.device)
+            print("set cond_hint", self.cond_hint.shape)
+        control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
+        for x in control:
+            x *= self.strength
+        return control
+
+    def set_cond_hint(self, cond_hint, strength=1.0):
+        self.cond_hint_original = cond_hint
+        self.strength = strength
+        return self
+
+    def cleanup(self):
+        if self.cond_hint is not None:
+            del self.cond_hint
+            self.cond_hint = None
+
+    def copy(self):
+        c = ControlNet(self.control_model)
+        c.cond_hint_original = self.cond_hint_original
+        c.strength = self.strength
+        return c
+
+def load_controlnet(ckpt_path):
+    controlnet_data = load_torch_file(ckpt_path)
+    pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
+    pth = False
+    sd2 = False
+    key = 'input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'
+    if pth_key in controlnet_data:
+        pth = True
+        key = pth_key
+    elif key in controlnet_data:
+        pass
+    else:
+        print("error checkpoint does not contain controlnet data", ckpt_path)
+        return None
+
+    context_dim = controlnet_data[key].shape[1]
+    control_model = cldm.ControlNet(image_size=32,
+                                    in_channels=4,
+                                    hint_channels=3,
+                                    model_channels=320,
+                                    attention_resolutions=[4, 2, 1],
+                                    num_res_blocks=2,
+                                    channel_mult=[1, 2, 4, 4],
+                                    num_heads=8,
+                                    use_spatial_transformer=True,
+                                    transformer_depth=1,
+                                    context_dim=context_dim,
+                                    use_checkpoint=True,
+                                    legacy=False)
+
+    if pth:
+        class WeightsLoader(torch.nn.Module):
+            pass
+        w = WeightsLoader()
+        w.control_model = control_model
+        w.load_state_dict(controlnet_data, strict=False)
+    else:
+        control_model.load_state_dict(controlnet_data, strict=False)
+
+    control = ControlNet(control_model)
+    return control
+

 def load_clip(ckpt_path, embedding_directory=None):
     clip_data = load_torch_file(ckpt_path)
     config = {}
0 → 100644
View file @
fa66ece2
import
torch
def
common_upscale
(
samples
,
width
,
height
,
upscale_method
,
crop
):
if
crop
==
"center"
:
old_width
=
samples
.
shape
[
3
]
old_height
=
samples
.
shape
[
2
]
old_aspect
=
old_width
/
old_height
new_aspect
=
width
/
height
x
=
0
y
=
0
if
old_aspect
>
new_aspect
:
x
=
round
((
old_width
-
old_width
*
(
new_aspect
/
old_aspect
))
/
2
)
elif
old_aspect
<
new_aspect
:
y
=
round
((
old_height
-
old_height
*
(
old_aspect
/
new_aspect
))
/
2
)
s
=
samples
[:,:,
y
:
old_height
-
y
,
x
:
old_width
-
x
]
else
:
s
=
samples
return
torch
.
nn
.
functional
.
interpolate
(
s
,
size
=
(
height
,
width
),
mode
=
upscale_method
)
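
Example use of the helper: a 768-wide latent batch is center-cropped to the target aspect ratio before resampling, so nothing gets stretched:

import torch
import comfy.utils

samples = torch.rand(1, 4, 512, 768)  # NCHW
out = comfy.utils.common_upscale(samples, 512, 512, "nearest-exact", "center")
assert out.shape == (1, 4, 512, 512)  # 128 columns cropped from each side first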
main.py

@@ -7,6 +7,10 @@ import heapq
 import traceback
 import asyncio

+if os.name == "nt":
+    import logging
+    logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
+
 try:
     import aiohttp
     from aiohttp import web
models/configs/v2-inpainting-inference.yaml (new file, 0 → 100644)

model:
  base_learning_rate: 5.0e-05
  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: hybrid
    scale_factor: 0.18215
    monitor: val/loss_simple_ema
    finetune_keys: null
    use_ema: False

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        image_size: 32 # unused
        in_channels: 9
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
            - 1
            - 2
            - 4
            - 4
          num_res_blocks: 2
          attn_resolutions: [ ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"

data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: null # for concat as in LAION-A
    p_unsafe_threshold: 0.1
    filter_word_list: "data/filters.yaml"
    max_pwatermark: 0.45
    batch_size: 8
    num_workers: 6
    multinode: True
    min_size: 512
    train:
      shards:
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
        - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar"
      shuffle: 10000
      image_key: jpg
      image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 512
            interpolation: 3
        - target: torchvision.transforms.RandomCrop
          params:
            size: 512
      postprocess:
        target: ldm.data.laion.AddMask
        params:
          mode: "512train-large"
          p_drop: 0.25
    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards:
        - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar -"
      shuffle: 0
      image_key: jpg
      image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 512
            interpolation: 3
        - target: torchvision.transforms.CenterCrop
          params:
            size: 512
      postprocess:
        target: ldm.data.laion.AddMask
        params:
          mode: "512train-large"
          p_drop: 0.25

lightning:
  find_unused_parameters: True
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 10000

    image_logger:
      target: main.ImageLogger
      params:
        enable_autocast: False
        disabled: False
        batch_frequency: 1000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 5.0
          unconditional_guidance_label: [""]
          ddim_steps: 50 # todo check these out for depth2img,
          ddim_eta: 0.0 # todo check these out for depth2img,

  trainer:
    benchmark: True
    val_check_interval: 5000000
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
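
A minimal sketch (an assumption, not shown in this commit) of how a config in this format is normally consumed, using the ldm helpers already imported by comfy/sd.py; actually running the inpainting model still requires loading its weights separately:

from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("models/configs/v2-inpainting-inference.yaml")
model = instantiate_from_config(config.model)  # builds LatentInpaintDiffusion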
models/controlnet/put_controlnets_here (new empty file, 0 → 100644)
nodes.py

@@ -15,11 +15,13 @@ sys.path.insert(0, os.path.join(sys.path[0], "comfy"))
 import comfy.samplers
 import comfy.sd
+import comfy.utils

 import model_management
 import importlib

-supported_ckpt_extensions = ['.ckpt']
-supported_pt_extensions = ['.ckpt', '.pt', '.bin']
+supported_ckpt_extensions = ['.ckpt', '.pth']
+supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth']
 try:
     import safetensors.torch
     supported_ckpt_extensions += ['.safetensors']

@@ -78,12 +80,14 @@ class ConditioningSetArea:
     CATEGORY = "conditioning"

     def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0):
-        c = copy.deepcopy(conditioning)
-        for t in c:
-            t[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
-            t[1]['strength'] = strength
-            t[1]['min_sigma'] = min_sigma
-            t[1]['max_sigma'] = max_sigma
+        c = []
+        for t in conditioning:
+            n = [t[0], t[1].copy()]
+            n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
+            n[1]['strength'] = strength
+            n[1]['min_sigma'] = min_sigma
+            n[1]['max_sigma'] = max_sigma
+            c.append(n)
         return (c, )

 class VAEDecode:

@@ -99,7 +103,7 @@ class VAEDecode:
     CATEGORY = "latent"

     def decode(self, vae, samples):
-        return (vae.decode(samples), )
+        return (vae.decode(samples["samples"]), )

 class VAEEncode:
     def __init__(self, device="cpu"):

@@ -118,7 +122,39 @@ class VAEEncode:
         y = (pixels.shape[2] // 64) * 64
         if pixels.shape[1] != x or pixels.shape[2] != y:
             pixels = pixels[:,:x,:y,:]
-        return (vae.encode(pixels), )
+        t = vae.encode(pixels[:,:,:,:3])
+
+        return ({"samples":t}, )
+
+class VAEEncodeForInpaint:
+    def __init__(self, device="cpu"):
+        self.device = device
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", )}}
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "encode"
+
+    CATEGORY = "latent/inpaint"
+
+    def encode(self, vae, pixels, mask):
+        x = (pixels.shape[1] // 64) * 64
+        y = (pixels.shape[2] // 64) * 64
+        if pixels.shape[1] != x or pixels.shape[2] != y:
+            pixels = pixels[:,:x,:y,:]
+            mask = mask[:x,:y]
+
+        #shave off a few pixels to keep things seamless
+        kernel_tensor = torch.ones((1, 1, 6, 6))
+        mask_erosion = torch.clamp(torch.nn.functional.conv2d((1.0 - mask.round())[None], kernel_tensor, padding=3), 0, 1)
+        for i in range(3):
+            pixels[:,:,:,i] -= 0.5
+            pixels[:,:,:,i] *= mask_erosion[0][:x,:y].round()
+            pixels[:,:,:,i] += 0.5
+        t = vae.encode(pixels)
+
+        return ({"samples":t, "noise_mask": mask}, )

 class CheckpointLoader:
     models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")

@@ -178,6 +214,48 @@ class VAELoader:
         vae = comfy.sd.VAE(ckpt_path=vae_path)
         return (vae,)

+class ControlNetLoader:
+    models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
+    controlnet_dir = os.path.join(models_dir, "controlnet")
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}}
+
+    RETURN_TYPES = ("CONTROL_NET",)
+    FUNCTION = "load_controlnet"
+
+    CATEGORY = "loaders"
+
+    def load_controlnet(self, control_net_name):
+        controlnet_path = os.path.join(self.controlnet_dir, control_net_name)
+        controlnet = comfy.sd.load_controlnet(controlnet_path)
+        return (controlnet,)
+
+class ControlNetApply:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"conditioning": ("CONDITIONING", ),
+                             "control_net": ("CONTROL_NET", ),
+                             "image": ("IMAGE", ),
+                             "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01})
+                             }}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "apply_controlnet"
+
+    CATEGORY = "conditioning"
+
+    def apply_controlnet(self, conditioning, control_net, image, strength):
+        c = []
+        control_hint = image.movedim(-1, 1)
+        print(control_hint.shape)
+        for t in conditioning:
+            n = [t[0], t[1].copy()]
+            n[1]['control'] = control_net.copy().set_cond_hint(control_hint, strength)
+            c.append(n)
+        return (c, )
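
A small sketch (not from the commit) of why ControlNetApply calls control_net.copy() per conditioning list: copies share the underlying control_model weights but carry their own hint and strength, so one loaded ControlNet can serve several conditionings:

class FakeControlNet:
    # minimal stand-in mirroring comfy.sd.ControlNet's copy()/set_cond_hint()
    def __init__(self, control_model):
        self.control_model = control_model
        self.cond_hint_original, self.strength = None, 1.0
    def set_cond_hint(self, hint, strength=1.0):
        self.cond_hint_original, self.strength = hint, strength
        return self
    def copy(self):
        return FakeControlNet(self.control_model)

base = FakeControlNet(control_model="shared weights")
a = base.copy().set_cond_hint("edge map", 1.0)
b = base.copy().set_cond_hint("depth map", 0.5)
assert a.control_model is b.control_model and a.strength != b.strength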
class
CLIPLoader
:
class
CLIPLoader
:
models_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)),
"models"
)
models_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)),
"models"
)
clip_dir
=
os
.
path
.
join
(
models_dir
,
"clip"
)
clip_dir
=
os
.
path
.
join
(
models_dir
,
"clip"
)
...
@@ -213,24 +291,9 @@ class EmptyLatentImage:
...
@@ -213,24 +291,9 @@ class EmptyLatentImage:
def
generate
(
self
,
width
,
height
,
batch_size
=
1
):
def
generate
(
self
,
width
,
height
,
batch_size
=
1
):
latent
=
torch
.
zeros
([
batch_size
,
4
,
height
//
8
,
width
//
8
])
latent
=
torch
.
zeros
([
batch_size
,
4
,
height
//
8
,
width
//
8
])
return
(
latent
,
)
return
({
"samples"
:
latent
},
)
def
common_upscale
(
samples
,
width
,
height
,
upscale_method
,
crop
):
if
crop
==
"center"
:
old_width
=
samples
.
shape
[
3
]
old_height
=
samples
.
shape
[
2
]
old_aspect
=
old_width
/
old_height
new_aspect
=
width
/
height
x
=
0
y
=
0
if
old_aspect
>
new_aspect
:
x
=
round
((
old_width
-
old_width
*
(
new_aspect
/
old_aspect
))
/
2
)
elif
old_aspect
<
new_aspect
:
y
=
round
((
old_height
-
old_height
*
(
old_aspect
/
new_aspect
))
/
2
)
s
=
samples
[:,:,
y
:
old_height
-
y
,
x
:
old_width
-
x
]
else
:
s
=
samples
return
torch
.
nn
.
functional
.
interpolate
(
s
,
size
=
(
height
,
width
),
mode
=
upscale_method
)
class
LatentUpscale
:
class
LatentUpscale
:
upscale_methods
=
[
"nearest-exact"
,
"bilinear"
,
"area"
]
upscale_methods
=
[
"nearest-exact"
,
"bilinear"
,
"area"
]
...
@@ -248,7 +311,8 @@ class LatentUpscale:
...
@@ -248,7 +311,8 @@ class LatentUpscale:
CATEGORY
=
"latent"
CATEGORY
=
"latent"
def
upscale
(
self
,
samples
,
upscale_method
,
width
,
height
,
crop
):
def
upscale
(
self
,
samples
,
upscale_method
,
width
,
height
,
crop
):
s
=
common_upscale
(
samples
,
width
//
8
,
height
//
8
,
upscale_method
,
crop
)
s
=
samples
.
copy
()
s
[
"samples"
]
=
comfy
.
utils
.
common_upscale
(
samples
[
"samples"
],
width
//
8
,
height
//
8
,
upscale_method
,
crop
)
return
(
s
,)
return
(
s
,)
class
LatentRotate
:
class
LatentRotate
:
...
@@ -263,6 +327,7 @@ class LatentRotate:
...
@@ -263,6 +327,7 @@ class LatentRotate:
CATEGORY
=
"latent"
CATEGORY
=
"latent"
def
rotate
(
self
,
samples
,
rotation
):
def
rotate
(
self
,
samples
,
rotation
):
s
=
samples
.
copy
()
rotate_by
=
0
rotate_by
=
0
if
rotation
.
startswith
(
"90"
):
if
rotation
.
startswith
(
"90"
):
rotate_by
=
1
rotate_by
=
1
...
@@ -271,7 +336,7 @@ class LatentRotate:
...
@@ -271,7 +336,7 @@ class LatentRotate:
elif
rotation
.
startswith
(
"270"
):
elif
rotation
.
startswith
(
"270"
):
rotate_by
=
3
rotate_by
=
3
s
=
torch
.
rot90
(
samples
,
k
=
rotate_by
,
dims
=
[
3
,
2
])
s
[
"samples"
]
=
torch
.
rot90
(
samples
[
"samples"
]
,
k
=
rotate_by
,
dims
=
[
3
,
2
])
return
(
s
,)
return
(
s
,)
class
LatentFlip
:
class
LatentFlip
:
...
@@ -286,12 +351,11 @@ class LatentFlip:
...
@@ -286,12 +351,11 @@ class LatentFlip:
CATEGORY
=
"latent"
CATEGORY
=
"latent"
def
flip
(
self
,
samples
,
flip_method
):
def
flip
(
self
,
samples
,
flip_method
):
s
=
samples
.
copy
()
if
flip_method
.
startswith
(
"x"
):
if
flip_method
.
startswith
(
"x"
):
s
=
torch
.
flip
(
samples
,
dims
=
[
2
])
s
[
"samples"
]
=
torch
.
flip
(
samples
[
"samples"
]
,
dims
=
[
2
])
elif
flip_method
.
startswith
(
"y"
):
elif
flip_method
.
startswith
(
"y"
):
s
=
torch
.
flip
(
samples
,
dims
=
[
3
])
s
[
"samples"
]
=
torch
.
flip
(
samples
[
"samples"
],
dims
=
[
3
])
else
:
s
=
samples
return
(
s
,)
return
(
s
,)
...
@@ -313,12 +377,15 @@ class LatentComposite:
...
@@ -313,12 +377,15 @@ class LatentComposite:
        x = x // 8
        y = y // 8
        feather = feather // 8
-        s = samples_to.clone()
+        samples_out = samples_to.copy()
+        s = samples_to["samples"].clone()
+        samples_to = samples_to["samples"]
+        samples_from = samples_from["samples"]
        if feather == 0:
            s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
        else:
-            s_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
-            mask = torch.ones_like(s_from)
+            samples_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
+            mask = torch.ones_like(samples_from)
            for t in range(feather):
                if y != 0:
                    mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
...
@@ -331,7 +398,8 @@ class LatentComposite:
                mask[:,:,:,mask.shape[3] - 1 - t: mask.shape[3] - t] *= ((1.0/feather) * (t + 1))
            rev_mask = torch.ones_like(mask) - mask
            s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] * rev_mask
-        return (s,)
+        samples_out["samples"] = s
+        return (samples_out,)
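The feathering loop builds a linear ramp on each pasted edge that borders existing content, then blends with pasted * mask + background * (1 - mask). A small sketch with feather = 4 and a non-zero y offset, so only the top edge ramps:

import torch

feather = 4
mask = torch.ones(1, 4, 8, 8)
y = 8                                 # paste does not start at the top edge
for t in range(feather):
    if y != 0:
        mask[:, :, t:1 + t, :] *= (1.0 / feather) * (t + 1)
print(mask[0, 0, :, 0])
# tensor([0.2500, 0.5000, 0.7500, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])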
class LatentCrop:
    @classmethod
...
@@ -348,6 +416,8 @@ class LatentCrop:
    CATEGORY = "latent"

    def crop(self, samples, width, height, x, y):
+        s = samples.copy()
+        samples = samples['samples']
        x = x // 8
        y = y // 8
...
@@ -371,15 +441,43 @@ class LatentCrop:
        #make sure size is always multiple of 64
        x, to_x = enforce_image_dim(x, to_x, samples.shape[3])
        y, to_y = enforce_image_dim(y, to_y, samples.shape[2])
-        s = samples[:,:,y:to_y, x:to_x]
+        s['samples'] = samples[:,:,y:to_y, x:to_x]
        return (s,)
-def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
+class SetLatentNoiseMask:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"samples": ("LATENT",),
+                             "mask": ("MASK",),
+                             }}
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "set_mask"
+
+    CATEGORY = "latent/inpaint"
+
+    def set_mask(self, samples, mask):
+        s = samples.copy()
+        s["noise_mask"] = mask
+        return (s,)
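The new node simply attaches the mask to the latent dict; the sampler picks it up later. A hypothetical wiring sketch (the sizes and the convention that 1.0 marks the repaintable region are assumptions, inferred from the rounding and blending code below):

import torch

latent = {"samples": torch.zeros(1, 4, 64, 64)}
mask = torch.zeros(512, 512)
mask[128:384, 128:384] = 1.0          # region the sampler may repaint (assumed)

s = latent.copy()                     # mirrors what set_mask() does
s["noise_mask"] = mask
# common_ksampler finds "noise_mask" in the dict and forwards it as denoise_mask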
+def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
+    latent_image = latent["samples"]
+    noise_mask = None
    if disable_noise:
        noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
    else:
        noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu")

+    if "noise_mask" in latent:
+        noise_mask = latent['noise_mask']
+        noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear")
+        noise_mask = noise_mask.round()
+        noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1)
+        noise_mask = torch.cat([noise_mask] * noise.shape[0])
+        noise_mask = noise_mask.to(device)
+
    real_model = None
    if device != "cpu":
        model_management.load_model_gpu(model)
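The mask arrives at image resolution as HxW; the block above resizes it to the latent grid, binarizes it, and tiles it across channels and batch. A shape-only sketch:

import torch

noise = torch.randn(2, 4, 64, 64)
mask = torch.rand(512, 512)
m = torch.nn.functional.interpolate(mask[None, None,], size=(64, 64), mode="bilinear")
m = m.round()                                # hard 0/1 after resizing
m = torch.cat([m] * noise.shape[1], dim=1)   # (1, 4, 64, 64): one copy per channel
m = torch.cat([m] * noise.shape[0])          # (2, 4, 64, 64): one copy per batch item
print(m.shape)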
...
@@ -393,29 +491,40 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
    positive_copy = []
    negative_copy = []
+    control_nets = []
    for p in positive:
        t = p[0]
        if t.shape[0] < noise.shape[0]:
            t = torch.cat([t] * noise.shape[0])
        t = t.to(device)
+        if 'control' in p[1]:
+            control_nets += [p[1]['control']]
        positive_copy += [[t] + p[1:]]
    for n in negative:
        t = n[0]
        if t.shape[0] < noise.shape[0]:
            t = torch.cat([t] * noise.shape[0])
        t = t.to(device)
+        if 'control' in n[1]:
+            control_nets += [n[1]['control']]
        negative_copy += [[t] + n[1:]]

+    model_management.load_controlnet_gpu(list(map(lambda a: a.control_model, control_nets)))
+
    if sampler_name in comfy.samplers.KSampler.SAMPLERS:
        sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
    else:
        #other samplers
        pass

-    samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise)
+    samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
    samples = samples.cpu()
+    for c in control_nets:
+        c.cleanup()

-    return (samples, )
+    out = latent.copy()
+    out["samples"] = samples
+    return (out, )
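The gathering loops above assume each conditioning entry is a [tensor, options] pair and that ControlNetApply stores the network object under the 'control' key. A stub sketch of that contract (StubControlNet is a stand-in, not the real wrapper):

import torch

class StubControlNet:
    control_model = None
    def cleanup(self):
        pass

positive = [[torch.randn(1, 77, 768), {"control": StubControlNet()}]]
control_nets = [p[1]["control"] for p in positive if "control" in p[1]]
print(len(control_nets))   # 1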
class KSampler:
    def __init__(self, device="cuda"):
...
@@ -532,7 +641,7 @@ class LoadImage:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
-                    {"image": (os.listdir(s.input_dir), )},
+                    {"image": (sorted(os.listdir(s.input_dir)), )},
                }

    CATEGORY = "image"
...
@@ -541,10 +650,11 @@ class LoadImage:
    FUNCTION = "load_image"
    def load_image(self, image):
        image_path = os.path.join(self.input_dir, image)
-        image = Image.open(image_path).convert("RGB")
+        i = Image.open(image_path)
+        image = i.convert("RGB")
        image = np.array(image).astype(np.float32) / 255.0
        image = torch.from_numpy(image)[None,]
-        return image
+        return (image,)

    @classmethod
    def IS_CHANGED(s, image):
...
@@ -554,6 +664,41 @@ class LoadImage:
            m.update(f.read())
        return m.digest().hex()
+class LoadImageMask:
+    input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required":
+                    {"image": (os.listdir(s.input_dir), ),
+                     "channel": (["alpha", "red", "green", "blue"], ),}
+                }
+
+    CATEGORY = "image"
+
+    RETURN_TYPES = ("MASK",)
+    FUNCTION = "load_image"
+    def load_image(self, image, channel):
+        image_path = os.path.join(self.input_dir, image)
+        i = Image.open(image_path)
+        mask = None
+        c = channel[0].upper()
+        if c in i.getbands():
+            mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
+            mask = torch.from_numpy(mask)
+            if c == 'A':
+                mask = 1. - mask
+        else:
+            mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
+        return (mask,)
+
+    @classmethod
+    def IS_CHANGED(s, image, channel):
+        image_path = os.path.join(s.input_dir, image)
+        m = hashlib.sha256()
+        with open(image_path, 'rb') as f:
+            m.update(f.read())
+        return m.digest().hex()
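LoadImageMask reads one channel as a float mask and inverts alpha, so transparent pixels become 1.0. A self-contained sketch of that extraction, using a synthetic image:

import numpy as np
import torch
from PIL import Image

i = Image.new("RGBA", (64, 64), (255, 0, 0, 0))   # fully transparent red
c = "alpha"[0].upper()                            # 'A'
if c in i.getbands():
    mask = torch.from_numpy(np.array(i.getchannel(c)).astype(np.float32) / 255.0)
    if c == "A":
        mask = 1.0 - mask
print(mask.max())                                 # tensor(1.)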
class ImageScale:
    upscale_methods = ["nearest-exact", "bilinear", "area"]
    crop_methods = ["disabled", "center"]
...
@@ -571,7 +716,7 @@ class ImageScale:
    def upscale(self, image, upscale_method, width, height, crop):
        samples = image.movedim(-1, 1)
-        s = common_upscale(samples, width, height, upscale_method, crop)
+        s = comfy.utils.common_upscale(samples, width, height, upscale_method, crop)
        s = s.movedim(1, -1)
        return (s,)
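IMAGE tensors in this codebase are channels-last (NHWC), while interpolate() wants channels-first, hence the movedim pair around the upscale. A quick shape check:

import torch

image = torch.rand(1, 512, 768, 3)
samples = image.movedim(-1, 1)       # (1, 3, 512, 768) for interpolate()
restored = samples.movedim(1, -1)    # (1, 512, 768, 3) for downstream nodes
print(samples.shape, restored.shape)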
...
@@ -581,21 +726,26 @@ NODE_CLASS_MAPPINGS = {
    "CLIPTextEncode": CLIPTextEncode,
    "VAEDecode": VAEDecode,
    "VAEEncode": VAEEncode,
+    "VAEEncodeForInpaint": VAEEncodeForInpaint,
    "VAELoader": VAELoader,
    "EmptyLatentImage": EmptyLatentImage,
    "LatentUpscale": LatentUpscale,
    "SaveImage": SaveImage,
    "LoadImage": LoadImage,
+    "LoadImageMask": LoadImageMask,
    "ImageScale": ImageScale,
    "ConditioningCombine": ConditioningCombine,
    "ConditioningSetArea": ConditioningSetArea,
    "KSamplerAdvanced": KSamplerAdvanced,
+    "SetLatentNoiseMask": SetLatentNoiseMask,
    "LatentComposite": LatentComposite,
    "LatentRotate": LatentRotate,
    "LatentFlip": LatentFlip,
    "LatentCrop": LatentCrop,
    "LoraLoader": LoraLoader,
    "CLIPLoader": CLIPLoader,
+    "ControlNetApply": ControlNetApply,
+    "ControlNetLoader": ControlNetLoader,
}

CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")
...
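CUSTOM_NODE_PATH points at a custom_nodes directory next to nodes.py; the loader is assumed (per the small main.py change in this commit) to merge any NODE_CLASS_MAPPINGS dict it finds in files there. A hypothetical custom_nodes/example.py under that assumption:

class LatentPassthrough:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples": ("LATENT",)}}
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "run"
    CATEGORY = "latent"

    def run(self, samples):
        return (samples.copy(),)      # copy the dict, per the new LATENT convention

NODE_CLASS_MAPPINGS = {"LatentPassthrough": LatentPassthrough}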