Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
38833ceb
Commit
38833ceb
authored
Apr 01, 2023
by
pythongosssss
Browse files
Merge remote-tracking branch 'origin/master' into custom_routes
parents
313f1f83
0d972b85
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
200 additions
and
41 deletions
+200
-41
comfy/ldm/models/diffusion/ddim.py
comfy/ldm/models/diffusion/ddim.py
+7
-7
comfy/ldm/models/diffusion/ddpm.py
comfy/ldm/models/diffusion/ddpm.py
+9
-9
comfy/ldm/modules/attention.py
comfy/ldm/modules/attention.py
+18
-7
comfy/ldm/modules/diffusionmodules/openaimodel.py
comfy/ldm/modules/diffusionmodules/openaimodel.py
+7
-6
comfy/ldm/modules/tomesd.py
comfy/ldm/modules/tomesd.py
+117
-0
comfy/samplers.py
comfy/samplers.py
+14
-10
comfy/sd.py
comfy/sd.py
+9
-0
nodes.py
nodes.py
+18
-1
web/extensions/core/widgetInputs.js
web/extensions/core/widgetInputs.js
+1
-1
No files found.
comfy/ldm/models/diffusion/ddim.py
View file @
38833ceb
...
@@ -78,7 +78,7 @@ class DDIMSampler(object):
...
@@ -78,7 +78,7 @@ class DDIMSampler(object):
dynamic_threshold
=
None
,
dynamic_threshold
=
None
,
ucg_schedule
=
None
,
ucg_schedule
=
None
,
denoise_function
=
None
,
denoise_function
=
None
,
cond_concat
=
None
,
extra_args
=
None
,
to_zero
=
True
,
to_zero
=
True
,
end_step
=
None
,
end_step
=
None
,
**
kwargs
**
kwargs
...
@@ -101,7 +101,7 @@ class DDIMSampler(object):
...
@@ -101,7 +101,7 @@ class DDIMSampler(object):
dynamic_threshold
=
dynamic_threshold
,
dynamic_threshold
=
dynamic_threshold
,
ucg_schedule
=
ucg_schedule
,
ucg_schedule
=
ucg_schedule
,
denoise_function
=
denoise_function
,
denoise_function
=
denoise_function
,
cond_concat
=
cond_concat
,
extra_args
=
extra_args
,
to_zero
=
to_zero
,
to_zero
=
to_zero
,
end_step
=
end_step
end_step
=
end_step
)
)
...
@@ -174,7 +174,7 @@ class DDIMSampler(object):
...
@@ -174,7 +174,7 @@ class DDIMSampler(object):
dynamic_threshold
=
dynamic_threshold
,
dynamic_threshold
=
dynamic_threshold
,
ucg_schedule
=
ucg_schedule
,
ucg_schedule
=
ucg_schedule
,
denoise_function
=
None
,
denoise_function
=
None
,
cond_concat
=
None
extra_args
=
None
)
)
return
samples
,
intermediates
return
samples
,
intermediates
...
@@ -185,7 +185,7 @@ class DDIMSampler(object):
...
@@ -185,7 +185,7 @@ class DDIMSampler(object):
mask
=
None
,
x0
=
None
,
img_callback
=
None
,
log_every_t
=
100
,
mask
=
None
,
x0
=
None
,
img_callback
=
None
,
log_every_t
=
100
,
temperature
=
1.
,
noise_dropout
=
0.
,
score_corrector
=
None
,
corrector_kwargs
=
None
,
temperature
=
1.
,
noise_dropout
=
0.
,
score_corrector
=
None
,
corrector_kwargs
=
None
,
unconditional_guidance_scale
=
1.
,
unconditional_conditioning
=
None
,
dynamic_threshold
=
None
,
unconditional_guidance_scale
=
1.
,
unconditional_conditioning
=
None
,
dynamic_threshold
=
None
,
ucg_schedule
=
None
,
denoise_function
=
None
,
cond_concat
=
None
,
to_zero
=
True
,
end_step
=
None
):
ucg_schedule
=
None
,
denoise_function
=
None
,
extra_args
=
None
,
to_zero
=
True
,
end_step
=
None
):
device
=
self
.
model
.
betas
.
device
device
=
self
.
model
.
betas
.
device
b
=
shape
[
0
]
b
=
shape
[
0
]
if
x_T
is
None
:
if
x_T
is
None
:
...
@@ -225,7 +225,7 @@ class DDIMSampler(object):
...
@@ -225,7 +225,7 @@ class DDIMSampler(object):
corrector_kwargs
=
corrector_kwargs
,
corrector_kwargs
=
corrector_kwargs
,
unconditional_guidance_scale
=
unconditional_guidance_scale
,
unconditional_guidance_scale
=
unconditional_guidance_scale
,
unconditional_conditioning
=
unconditional_conditioning
,
unconditional_conditioning
=
unconditional_conditioning
,
dynamic_threshold
=
dynamic_threshold
,
denoise_function
=
denoise_function
,
cond_concat
=
cond_concat
)
dynamic_threshold
=
dynamic_threshold
,
denoise_function
=
denoise_function
,
extra_args
=
extra_args
)
img
,
pred_x0
=
outs
img
,
pred_x0
=
outs
if
callback
:
callback
(
i
)
if
callback
:
callback
(
i
)
if
img_callback
:
img_callback
(
pred_x0
,
i
)
if
img_callback
:
img_callback
(
pred_x0
,
i
)
...
@@ -249,11 +249,11 @@ class DDIMSampler(object):
...
@@ -249,11 +249,11 @@ class DDIMSampler(object):
def
p_sample_ddim
(
self
,
x
,
c
,
t
,
index
,
repeat_noise
=
False
,
use_original_steps
=
False
,
quantize_denoised
=
False
,
def
p_sample_ddim
(
self
,
x
,
c
,
t
,
index
,
repeat_noise
=
False
,
use_original_steps
=
False
,
quantize_denoised
=
False
,
temperature
=
1.
,
noise_dropout
=
0.
,
score_corrector
=
None
,
corrector_kwargs
=
None
,
temperature
=
1.
,
noise_dropout
=
0.
,
score_corrector
=
None
,
corrector_kwargs
=
None
,
unconditional_guidance_scale
=
1.
,
unconditional_conditioning
=
None
,
unconditional_guidance_scale
=
1.
,
unconditional_conditioning
=
None
,
dynamic_threshold
=
None
,
denoise_function
=
None
,
cond_concat
=
None
):
dynamic_threshold
=
None
,
denoise_function
=
None
,
extra_args
=
None
):
b
,
*
_
,
device
=
*
x
.
shape
,
x
.
device
b
,
*
_
,
device
=
*
x
.
shape
,
x
.
device
if
denoise_function
is
not
None
:
if
denoise_function
is
not
None
:
model_output
=
denoise_function
(
self
.
model
.
apply_model
,
x
,
t
,
unconditional_conditioning
,
c
,
unconditional_guidance_scale
,
cond_concat
)
model_output
=
denoise_function
(
self
.
model
.
apply_model
,
x
,
t
,
**
extra_args
)
elif
unconditional_conditioning
is
None
or
unconditional_guidance_scale
==
1.
:
elif
unconditional_conditioning
is
None
or
unconditional_guidance_scale
==
1.
:
model_output
=
self
.
model
.
apply_model
(
x
,
t
,
c
)
model_output
=
self
.
model
.
apply_model
(
x
,
t
,
c
)
else
:
else
:
...
...
comfy/ldm/models/diffusion/ddpm.py
View file @
38833ceb
...
@@ -1317,12 +1317,12 @@ class DiffusionWrapper(torch.nn.Module):
...
@@ -1317,12 +1317,12 @@ class DiffusionWrapper(torch.nn.Module):
self
.
conditioning_key
=
conditioning_key
self
.
conditioning_key
=
conditioning_key
assert
self
.
conditioning_key
in
[
None
,
'concat'
,
'crossattn'
,
'hybrid'
,
'adm'
,
'hybrid-adm'
,
'crossattn-adm'
]
assert
self
.
conditioning_key
in
[
None
,
'concat'
,
'crossattn'
,
'hybrid'
,
'adm'
,
'hybrid-adm'
,
'crossattn-adm'
]
def
forward
(
self
,
x
,
t
,
c_concat
:
list
=
None
,
c_crossattn
:
list
=
None
,
c_adm
=
None
,
control
=
None
):
def
forward
(
self
,
x
,
t
,
c_concat
:
list
=
None
,
c_crossattn
:
list
=
None
,
c_adm
=
None
,
control
=
None
,
transformer_options
=
{}
):
if
self
.
conditioning_key
is
None
:
if
self
.
conditioning_key
is
None
:
out
=
self
.
diffusion_model
(
x
,
t
,
control
=
control
)
out
=
self
.
diffusion_model
(
x
,
t
,
control
=
control
,
transformer_options
=
transformer_options
)
elif
self
.
conditioning_key
==
'concat'
:
elif
self
.
conditioning_key
==
'concat'
:
xc
=
torch
.
cat
([
x
]
+
c_concat
,
dim
=
1
)
xc
=
torch
.
cat
([
x
]
+
c_concat
,
dim
=
1
)
out
=
self
.
diffusion_model
(
xc
,
t
,
control
=
control
)
out
=
self
.
diffusion_model
(
xc
,
t
,
control
=
control
,
transformer_options
=
transformer_options
)
elif
self
.
conditioning_key
==
'crossattn'
:
elif
self
.
conditioning_key
==
'crossattn'
:
if
not
self
.
sequential_cross_attn
:
if
not
self
.
sequential_cross_attn
:
cc
=
torch
.
cat
(
c_crossattn
,
1
)
cc
=
torch
.
cat
(
c_crossattn
,
1
)
...
@@ -1332,25 +1332,25 @@ class DiffusionWrapper(torch.nn.Module):
...
@@ -1332,25 +1332,25 @@ class DiffusionWrapper(torch.nn.Module):
# TorchScript changes names of the arguments
# TorchScript changes names of the arguments
# with argument cc defined as context=cc scripted model will produce
# with argument cc defined as context=cc scripted model will produce
# an error: RuntimeError: forward() is missing value for argument 'argument_3'.
# an error: RuntimeError: forward() is missing value for argument 'argument_3'.
out
=
self
.
scripted_diffusion_model
(
x
,
t
,
cc
,
control
=
control
)
out
=
self
.
scripted_diffusion_model
(
x
,
t
,
cc
,
control
=
control
,
transformer_options
=
transformer_options
)
else
:
else
:
out
=
self
.
diffusion_model
(
x
,
t
,
context
=
cc
,
control
=
control
)
out
=
self
.
diffusion_model
(
x
,
t
,
context
=
cc
,
control
=
control
,
transformer_options
=
transformer_options
)
elif
self
.
conditioning_key
==
'hybrid'
:
elif
self
.
conditioning_key
==
'hybrid'
:
xc
=
torch
.
cat
([
x
]
+
c_concat
,
dim
=
1
)
xc
=
torch
.
cat
([
x
]
+
c_concat
,
dim
=
1
)
cc
=
torch
.
cat
(
c_crossattn
,
1
)
cc
=
torch
.
cat
(
c_crossattn
,
1
)
out
=
self
.
diffusion_model
(
xc
,
t
,
context
=
cc
,
control
=
control
)
out
=
self
.
diffusion_model
(
xc
,
t
,
context
=
cc
,
control
=
control
,
transformer_options
=
transformer_options
)
elif
self
.
conditioning_key
==
'hybrid-adm'
:
elif
self
.
conditioning_key
==
'hybrid-adm'
:
assert
c_adm
is
not
None
assert
c_adm
is
not
None
xc
=
torch
.
cat
([
x
]
+
c_concat
,
dim
=
1
)
xc
=
torch
.
cat
([
x
]
+
c_concat
,
dim
=
1
)
cc
=
torch
.
cat
(
c_crossattn
,
1
)
cc
=
torch
.
cat
(
c_crossattn
,
1
)
out
=
self
.
diffusion_model
(
xc
,
t
,
context
=
cc
,
y
=
c_adm
,
control
=
control
)
out
=
self
.
diffusion_model
(
xc
,
t
,
context
=
cc
,
y
=
c_adm
,
control
=
control
,
transformer_options
=
transformer_options
)
elif
self
.
conditioning_key
==
'crossattn-adm'
:
elif
self
.
conditioning_key
==
'crossattn-adm'
:
assert
c_adm
is
not
None
assert
c_adm
is
not
None
cc
=
torch
.
cat
(
c_crossattn
,
1
)
cc
=
torch
.
cat
(
c_crossattn
,
1
)
out
=
self
.
diffusion_model
(
x
,
t
,
context
=
cc
,
y
=
c_adm
,
control
=
control
)
out
=
self
.
diffusion_model
(
x
,
t
,
context
=
cc
,
y
=
c_adm
,
control
=
control
,
transformer_options
=
transformer_options
)
elif
self
.
conditioning_key
==
'adm'
:
elif
self
.
conditioning_key
==
'adm'
:
cc
=
c_crossattn
[
0
]
cc
=
c_crossattn
[
0
]
out
=
self
.
diffusion_model
(
x
,
t
,
y
=
cc
,
control
=
control
)
out
=
self
.
diffusion_model
(
x
,
t
,
y
=
cc
,
control
=
control
,
transformer_options
=
transformer_options
)
else
:
else
:
raise
NotImplementedError
()
raise
NotImplementedError
()
...
...
comfy/ldm/modules/attention.py
View file @
38833ceb
...
@@ -11,6 +11,7 @@ from .sub_quadratic_attention import efficient_dot_product_attention
...
@@ -11,6 +11,7 @@ from .sub_quadratic_attention import efficient_dot_product_attention
import
model_management
import
model_management
from
.
import
tomesd
if
model_management
.
xformers_enabled
():
if
model_management
.
xformers_enabled
():
import
xformers
import
xformers
...
@@ -504,12 +505,22 @@ class BasicTransformerBlock(nn.Module):
...
@@ -504,12 +505,22 @@ class BasicTransformerBlock(nn.Module):
self
.
norm3
=
nn
.
LayerNorm
(
dim
)
self
.
norm3
=
nn
.
LayerNorm
(
dim
)
self
.
checkpoint
=
checkpoint
self
.
checkpoint
=
checkpoint
def
forward
(
self
,
x
,
context
=
None
):
def
forward
(
self
,
x
,
context
=
None
,
transformer_options
=
{}
):
return
checkpoint
(
self
.
_forward
,
(
x
,
context
),
self
.
parameters
(),
self
.
checkpoint
)
return
checkpoint
(
self
.
_forward
,
(
x
,
context
,
transformer_options
),
self
.
parameters
(),
self
.
checkpoint
)
def
_forward
(
self
,
x
,
context
=
None
):
def
_forward
(
self
,
x
,
context
=
None
,
transformer_options
=
{}):
x
=
self
.
attn1
(
self
.
norm1
(
x
),
context
=
context
if
self
.
disable_self_attn
else
None
)
+
x
n
=
self
.
norm1
(
x
)
x
=
self
.
attn2
(
self
.
norm2
(
x
),
context
=
context
)
+
x
if
"tomesd"
in
transformer_options
:
m
,
u
=
tomesd
.
get_functions
(
x
,
transformer_options
[
"tomesd"
][
"ratio"
],
transformer_options
[
"original_shape"
])
n
=
u
(
self
.
attn1
(
m
(
n
),
context
=
context
if
self
.
disable_self_attn
else
None
))
else
:
n
=
self
.
attn1
(
n
,
context
=
context
if
self
.
disable_self_attn
else
None
)
x
+=
n
n
=
self
.
norm2
(
x
)
n
=
self
.
attn2
(
n
,
context
=
context
)
x
+=
n
x
=
self
.
ff
(
self
.
norm3
(
x
))
+
x
x
=
self
.
ff
(
self
.
norm3
(
x
))
+
x
return
x
return
x
...
@@ -557,7 +568,7 @@ class SpatialTransformer(nn.Module):
...
@@ -557,7 +568,7 @@ class SpatialTransformer(nn.Module):
self
.
proj_out
=
zero_module
(
nn
.
Linear
(
in_channels
,
inner_dim
))
self
.
proj_out
=
zero_module
(
nn
.
Linear
(
in_channels
,
inner_dim
))
self
.
use_linear
=
use_linear
self
.
use_linear
=
use_linear
def
forward
(
self
,
x
,
context
=
None
):
def
forward
(
self
,
x
,
context
=
None
,
transformer_options
=
{}
):
# note: if no context is given, cross-attention defaults to self-attention
# note: if no context is given, cross-attention defaults to self-attention
if
not
isinstance
(
context
,
list
):
if
not
isinstance
(
context
,
list
):
context
=
[
context
]
context
=
[
context
]
...
@@ -570,7 +581,7 @@ class SpatialTransformer(nn.Module):
...
@@ -570,7 +581,7 @@ class SpatialTransformer(nn.Module):
if
self
.
use_linear
:
if
self
.
use_linear
:
x
=
self
.
proj_in
(
x
)
x
=
self
.
proj_in
(
x
)
for
i
,
block
in
enumerate
(
self
.
transformer_blocks
):
for
i
,
block
in
enumerate
(
self
.
transformer_blocks
):
x
=
block
(
x
,
context
=
context
[
i
])
x
=
block
(
x
,
context
=
context
[
i
]
,
transformer_options
=
transformer_options
)
if
self
.
use_linear
:
if
self
.
use_linear
:
x
=
self
.
proj_out
(
x
)
x
=
self
.
proj_out
(
x
)
x
=
rearrange
(
x
,
'b (h w) c -> b c h w'
,
h
=
h
,
w
=
w
).
contiguous
()
x
=
rearrange
(
x
,
'b (h w) c -> b c h w'
,
h
=
h
,
w
=
w
).
contiguous
()
...
...
comfy/ldm/modules/diffusionmodules/openaimodel.py
View file @
38833ceb
...
@@ -76,12 +76,12 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
...
@@ -76,12 +76,12 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
support it as an extra input.
support it as an extra input.
"""
"""
def
forward
(
self
,
x
,
emb
,
context
=
None
):
def
forward
(
self
,
x
,
emb
,
context
=
None
,
transformer_options
=
{}
):
for
layer
in
self
:
for
layer
in
self
:
if
isinstance
(
layer
,
TimestepBlock
):
if
isinstance
(
layer
,
TimestepBlock
):
x
=
layer
(
x
,
emb
)
x
=
layer
(
x
,
emb
)
elif
isinstance
(
layer
,
SpatialTransformer
):
elif
isinstance
(
layer
,
SpatialTransformer
):
x
=
layer
(
x
,
context
)
x
=
layer
(
x
,
context
,
transformer_options
)
else
:
else
:
x
=
layer
(
x
)
x
=
layer
(
x
)
return
x
return
x
...
@@ -753,7 +753,7 @@ class UNetModel(nn.Module):
...
@@ -753,7 +753,7 @@ class UNetModel(nn.Module):
self
.
middle_block
.
apply
(
convert_module_to_f32
)
self
.
middle_block
.
apply
(
convert_module_to_f32
)
self
.
output_blocks
.
apply
(
convert_module_to_f32
)
self
.
output_blocks
.
apply
(
convert_module_to_f32
)
def
forward
(
self
,
x
,
timesteps
=
None
,
context
=
None
,
y
=
None
,
control
=
None
,
**
kwargs
):
def
forward
(
self
,
x
,
timesteps
=
None
,
context
=
None
,
y
=
None
,
control
=
None
,
transformer_options
=
{},
**
kwargs
):
"""
"""
Apply the model to an input batch.
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param x: an [N x C x ...] Tensor of inputs.
...
@@ -762,6 +762,7 @@ class UNetModel(nn.Module):
...
@@ -762,6 +762,7 @@ class UNetModel(nn.Module):
:param y: an [N] Tensor of labels, if class-conditional.
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
:return: an [N x C x ...] Tensor of outputs.
"""
"""
transformer_options
[
"original_shape"
]
=
list
(
x
.
shape
)
assert
(
y
is
not
None
)
==
(
assert
(
y
is
not
None
)
==
(
self
.
num_classes
is
not
None
self
.
num_classes
is
not
None
),
"must specify y if and only if the model is class-conditional"
),
"must specify y if and only if the model is class-conditional"
...
@@ -775,13 +776,13 @@ class UNetModel(nn.Module):
...
@@ -775,13 +776,13 @@ class UNetModel(nn.Module):
h
=
x
.
type
(
self
.
dtype
)
h
=
x
.
type
(
self
.
dtype
)
for
id
,
module
in
enumerate
(
self
.
input_blocks
):
for
id
,
module
in
enumerate
(
self
.
input_blocks
):
h
=
module
(
h
,
emb
,
context
)
h
=
module
(
h
,
emb
,
context
,
transformer_options
)
if
control
is
not
None
and
'input'
in
control
and
len
(
control
[
'input'
])
>
0
:
if
control
is
not
None
and
'input'
in
control
and
len
(
control
[
'input'
])
>
0
:
ctrl
=
control
[
'input'
].
pop
()
ctrl
=
control
[
'input'
].
pop
()
if
ctrl
is
not
None
:
if
ctrl
is
not
None
:
h
+=
ctrl
h
+=
ctrl
hs
.
append
(
h
)
hs
.
append
(
h
)
h
=
self
.
middle_block
(
h
,
emb
,
context
)
h
=
self
.
middle_block
(
h
,
emb
,
context
,
transformer_options
)
if
control
is
not
None
and
'middle'
in
control
and
len
(
control
[
'middle'
])
>
0
:
if
control
is
not
None
and
'middle'
in
control
and
len
(
control
[
'middle'
])
>
0
:
h
+=
control
[
'middle'
].
pop
()
h
+=
control
[
'middle'
].
pop
()
...
@@ -793,7 +794,7 @@ class UNetModel(nn.Module):
...
@@ -793,7 +794,7 @@ class UNetModel(nn.Module):
hsp
+=
ctrl
hsp
+=
ctrl
h
=
th
.
cat
([
h
,
hsp
],
dim
=
1
)
h
=
th
.
cat
([
h
,
hsp
],
dim
=
1
)
del
hsp
del
hsp
h
=
module
(
h
,
emb
,
context
)
h
=
module
(
h
,
emb
,
context
,
transformer_options
)
h
=
h
.
type
(
x
.
dtype
)
h
=
h
.
type
(
x
.
dtype
)
if
self
.
predict_codebook_ids
:
if
self
.
predict_codebook_ids
:
return
self
.
id_predictor
(
h
)
return
self
.
id_predictor
(
h
)
...
...
comfy/ldm/modules/tomesd.py
0 → 100644
View file @
38833ceb
import
torch
from
typing
import
Tuple
,
Callable
import
math
def
do_nothing
(
x
:
torch
.
Tensor
,
mode
:
str
=
None
):
return
x
def
bipartite_soft_matching_random2d
(
metric
:
torch
.
Tensor
,
w
:
int
,
h
:
int
,
sx
:
int
,
sy
:
int
,
r
:
int
,
no_rand
:
bool
=
False
)
->
Tuple
[
Callable
,
Callable
]:
"""
Partitions the tokens into src and dst and merges r tokens from src to dst.
Dst tokens are partitioned by choosing one randomy in each (sx, sy) region.
Args:
- metric [B, N, C]: metric to use for similarity
- w: image width in tokens
- h: image height in tokens
- sx: stride in the x dimension for dst, must divide w
- sy: stride in the y dimension for dst, must divide h
- r: number of tokens to remove (by merging)
- no_rand: if true, disable randomness (use top left corner only)
"""
B
,
N
,
_
=
metric
.
shape
if
r
<=
0
:
return
do_nothing
,
do_nothing
with
torch
.
no_grad
():
hsy
,
wsx
=
h
//
sy
,
w
//
sx
# For each sy by sx kernel, randomly assign one token to be dst and the rest src
idx_buffer
=
torch
.
zeros
(
1
,
hsy
,
wsx
,
sy
*
sx
,
1
,
device
=
metric
.
device
)
if
no_rand
:
rand_idx
=
torch
.
zeros
(
1
,
hsy
,
wsx
,
1
,
1
,
device
=
metric
.
device
,
dtype
=
torch
.
int64
)
else
:
rand_idx
=
torch
.
randint
(
sy
*
sx
,
size
=
(
1
,
hsy
,
wsx
,
1
,
1
),
device
=
metric
.
device
)
idx_buffer
.
scatter_
(
dim
=
3
,
index
=
rand_idx
,
src
=-
torch
.
ones_like
(
rand_idx
,
dtype
=
idx_buffer
.
dtype
))
idx_buffer
=
idx_buffer
.
view
(
1
,
hsy
,
wsx
,
sy
,
sx
,
1
).
transpose
(
2
,
3
).
reshape
(
1
,
N
,
1
)
rand_idx
=
idx_buffer
.
argsort
(
dim
=
1
)
num_dst
=
int
((
1
/
(
sx
*
sy
))
*
N
)
a_idx
=
rand_idx
[:,
num_dst
:,
:]
# src
b_idx
=
rand_idx
[:,
:
num_dst
,
:]
# dst
def
split
(
x
):
C
=
x
.
shape
[
-
1
]
src
=
x
.
gather
(
dim
=
1
,
index
=
a_idx
.
expand
(
B
,
N
-
num_dst
,
C
))
dst
=
x
.
gather
(
dim
=
1
,
index
=
b_idx
.
expand
(
B
,
num_dst
,
C
))
return
src
,
dst
metric
=
metric
/
metric
.
norm
(
dim
=-
1
,
keepdim
=
True
)
a
,
b
=
split
(
metric
)
scores
=
a
@
b
.
transpose
(
-
1
,
-
2
)
# Can't reduce more than the # tokens in src
r
=
min
(
a
.
shape
[
1
],
r
)
node_max
,
node_idx
=
scores
.
max
(
dim
=-
1
)
edge_idx
=
node_max
.
argsort
(
dim
=-
1
,
descending
=
True
)[...,
None
]
unm_idx
=
edge_idx
[...,
r
:,
:]
# Unmerged Tokens
src_idx
=
edge_idx
[...,
:
r
,
:]
# Merged Tokens
dst_idx
=
node_idx
[...,
None
].
gather
(
dim
=-
2
,
index
=
src_idx
)
def
merge
(
x
:
torch
.
Tensor
,
mode
=
"mean"
)
->
torch
.
Tensor
:
src
,
dst
=
split
(
x
)
n
,
t1
,
c
=
src
.
shape
unm
=
src
.
gather
(
dim
=-
2
,
index
=
unm_idx
.
expand
(
n
,
t1
-
r
,
c
))
src
=
src
.
gather
(
dim
=-
2
,
index
=
src_idx
.
expand
(
n
,
r
,
c
))
dst
=
dst
.
scatter_reduce
(
-
2
,
dst_idx
.
expand
(
n
,
r
,
c
),
src
,
reduce
=
mode
)
return
torch
.
cat
([
unm
,
dst
],
dim
=
1
)
def
unmerge
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
unm_len
=
unm_idx
.
shape
[
1
]
unm
,
dst
=
x
[...,
:
unm_len
,
:],
x
[...,
unm_len
:,
:]
_
,
_
,
c
=
unm
.
shape
src
=
dst
.
gather
(
dim
=-
2
,
index
=
dst_idx
.
expand
(
B
,
r
,
c
))
# Combine back to the original shape
out
=
torch
.
zeros
(
B
,
N
,
c
,
device
=
x
.
device
,
dtype
=
x
.
dtype
)
out
.
scatter_
(
dim
=-
2
,
index
=
b_idx
.
expand
(
B
,
num_dst
,
c
),
src
=
dst
)
out
.
scatter_
(
dim
=-
2
,
index
=
a_idx
.
expand
(
B
,
a_idx
.
shape
[
1
],
1
).
gather
(
dim
=
1
,
index
=
unm_idx
).
expand
(
B
,
unm_len
,
c
),
src
=
unm
)
out
.
scatter_
(
dim
=-
2
,
index
=
a_idx
.
expand
(
B
,
a_idx
.
shape
[
1
],
1
).
gather
(
dim
=
1
,
index
=
src_idx
).
expand
(
B
,
r
,
c
),
src
=
src
)
return
out
return
merge
,
unmerge
def
get_functions
(
x
,
ratio
,
original_shape
):
b
,
c
,
original_h
,
original_w
=
original_shape
original_tokens
=
original_h
*
original_w
downsample
=
int
(
math
.
sqrt
(
original_tokens
//
x
.
shape
[
1
]))
stride_x
=
2
stride_y
=
2
max_downsample
=
1
if
downsample
<=
max_downsample
:
w
=
original_w
//
downsample
h
=
original_h
//
downsample
r
=
int
(
x
.
shape
[
1
]
*
ratio
)
no_rand
=
False
m
,
u
=
bipartite_soft_matching_random2d
(
x
,
w
,
h
,
stride_x
,
stride_y
,
r
,
no_rand
)
return
m
,
u
nothing
=
lambda
y
:
y
return
nothing
,
nothing
comfy/samplers.py
View file @
38833ceb
...
@@ -26,7 +26,7 @@ class CFGDenoiser(torch.nn.Module):
...
@@ -26,7 +26,7 @@ class CFGDenoiser(torch.nn.Module):
#The main sampling function shared by all the samplers
#The main sampling function shared by all the samplers
#Returns predicted noise
#Returns predicted noise
def
sampling_function
(
model_function
,
x
,
timestep
,
uncond
,
cond
,
cond_scale
,
cond_concat
=
None
):
def
sampling_function
(
model_function
,
x
,
timestep
,
uncond
,
cond
,
cond_scale
,
cond_concat
=
None
,
model_options
=
{}
):
def
get_area_and_mult
(
cond
,
x_in
,
cond_concat_in
,
timestep_in
):
def
get_area_and_mult
(
cond
,
x_in
,
cond_concat_in
,
timestep_in
):
area
=
(
x_in
.
shape
[
2
],
x_in
.
shape
[
3
],
0
,
0
)
area
=
(
x_in
.
shape
[
2
],
x_in
.
shape
[
3
],
0
,
0
)
strength
=
1.0
strength
=
1.0
...
@@ -104,7 +104,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
...
@@ -104,7 +104,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
out
[
'c_concat'
]
=
[
torch
.
cat
(
c_concat
)]
out
[
'c_concat'
]
=
[
torch
.
cat
(
c_concat
)]
return
out
return
out
def
calc_cond_uncond_batch
(
model_function
,
cond
,
uncond
,
x_in
,
timestep
,
max_total_area
,
cond_concat_in
):
def
calc_cond_uncond_batch
(
model_function
,
cond
,
uncond
,
x_in
,
timestep
,
max_total_area
,
cond_concat_in
,
model_options
):
out_cond
=
torch
.
zeros_like
(
x_in
)
out_cond
=
torch
.
zeros_like
(
x_in
)
out_count
=
torch
.
ones_like
(
x_in
)
/
100000.0
out_count
=
torch
.
ones_like
(
x_in
)
/
100000.0
...
@@ -169,6 +169,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
...
@@ -169,6 +169,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
if
control
is
not
None
:
if
control
is
not
None
:
c
[
'control'
]
=
control
.
get_control
(
input_x
,
timestep_
,
c
[
'c_crossattn'
],
len
(
cond_or_uncond
))
c
[
'control'
]
=
control
.
get_control
(
input_x
,
timestep_
,
c
[
'c_crossattn'
],
len
(
cond_or_uncond
))
if
'transformer_options'
in
model_options
:
c
[
'transformer_options'
]
=
model_options
[
'transformer_options'
]
output
=
model_function
(
input_x
,
timestep_
,
cond
=
c
).
chunk
(
batch_chunks
)
output
=
model_function
(
input_x
,
timestep_
,
cond
=
c
).
chunk
(
batch_chunks
)
del
input_x
del
input_x
...
@@ -192,7 +195,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
...
@@ -192,7 +195,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
max_total_area
=
model_management
.
maximum_batch_area
()
max_total_area
=
model_management
.
maximum_batch_area
()
cond
,
uncond
=
calc_cond_uncond_batch
(
model_function
,
cond
,
uncond
,
x
,
timestep
,
max_total_area
,
cond_concat
)
cond
,
uncond
=
calc_cond_uncond_batch
(
model_function
,
cond
,
uncond
,
x
,
timestep
,
max_total_area
,
cond_concat
,
model_options
)
return
uncond
+
(
cond
-
uncond
)
*
cond_scale
return
uncond
+
(
cond
-
uncond
)
*
cond_scale
...
@@ -209,8 +212,8 @@ class CFGNoisePredictor(torch.nn.Module):
...
@@ -209,8 +212,8 @@ class CFGNoisePredictor(torch.nn.Module):
super
().
__init__
()
super
().
__init__
()
self
.
inner_model
=
model
self
.
inner_model
=
model
self
.
alphas_cumprod
=
model
.
alphas_cumprod
self
.
alphas_cumprod
=
model
.
alphas_cumprod
def
apply_model
(
self
,
x
,
timestep
,
cond
,
uncond
,
cond_scale
,
cond_concat
=
None
):
def
apply_model
(
self
,
x
,
timestep
,
cond
,
uncond
,
cond_scale
,
cond_concat
=
None
,
model_options
=
{}
):
out
=
sampling_function
(
self
.
inner_model
.
apply_model
,
x
,
timestep
,
uncond
,
cond
,
cond_scale
,
cond_concat
)
out
=
sampling_function
(
self
.
inner_model
.
apply_model
,
x
,
timestep
,
uncond
,
cond
,
cond_scale
,
cond_concat
,
model_options
=
model_options
)
return
out
return
out
...
@@ -218,11 +221,11 @@ class KSamplerX0Inpaint(torch.nn.Module):
...
@@ -218,11 +221,11 @@ class KSamplerX0Inpaint(torch.nn.Module):
def
__init__
(
self
,
model
):
def
__init__
(
self
,
model
):
super
().
__init__
()
super
().
__init__
()
self
.
inner_model
=
model
self
.
inner_model
=
model
def
forward
(
self
,
x
,
sigma
,
uncond
,
cond
,
cond_scale
,
denoise_mask
,
cond_concat
=
None
):
def
forward
(
self
,
x
,
sigma
,
uncond
,
cond
,
cond_scale
,
denoise_mask
,
cond_concat
=
None
,
model_options
=
{}
):
if
denoise_mask
is
not
None
:
if
denoise_mask
is
not
None
:
latent_mask
=
1.
-
denoise_mask
latent_mask
=
1.
-
denoise_mask
x
=
x
*
denoise_mask
+
(
self
.
latent_image
+
self
.
noise
*
sigma
.
reshape
([
sigma
.
shape
[
0
]]
+
[
1
]
*
(
len
(
self
.
noise
.
shape
)
-
1
)))
*
latent_mask
x
=
x
*
denoise_mask
+
(
self
.
latent_image
+
self
.
noise
*
sigma
.
reshape
([
sigma
.
shape
[
0
]]
+
[
1
]
*
(
len
(
self
.
noise
.
shape
)
-
1
)))
*
latent_mask
out
=
self
.
inner_model
(
x
,
sigma
,
cond
=
cond
,
uncond
=
uncond
,
cond_scale
=
cond_scale
,
cond_concat
=
cond_concat
)
out
=
self
.
inner_model
(
x
,
sigma
,
cond
=
cond
,
uncond
=
uncond
,
cond_scale
=
cond_scale
,
cond_concat
=
cond_concat
,
model_options
=
model_options
)
if
denoise_mask
is
not
None
:
if
denoise_mask
is
not
None
:
out
*=
denoise_mask
out
*=
denoise_mask
...
@@ -330,7 +333,7 @@ class KSampler:
...
@@ -330,7 +333,7 @@ class KSampler:
"lms"
,
"dpm_fast"
,
"dpm_adaptive"
,
"dpmpp_2s_ancestral"
,
"dpmpp_sde"
,
"lms"
,
"dpm_fast"
,
"dpm_adaptive"
,
"dpmpp_2s_ancestral"
,
"dpmpp_sde"
,
"dpmpp_2m"
,
"ddim"
,
"uni_pc"
,
"uni_pc_bh2"
]
"dpmpp_2m"
,
"ddim"
,
"uni_pc"
,
"uni_pc_bh2"
]
def
__init__
(
self
,
model
,
steps
,
device
,
sampler
=
None
,
scheduler
=
None
,
denoise
=
None
):
def
__init__
(
self
,
model
,
steps
,
device
,
sampler
=
None
,
scheduler
=
None
,
denoise
=
None
,
model_options
=
{}
):
self
.
model
=
model
self
.
model
=
model
self
.
model_denoise
=
CFGNoisePredictor
(
self
.
model
)
self
.
model_denoise
=
CFGNoisePredictor
(
self
.
model
)
if
self
.
model
.
parameterization
==
"v"
:
if
self
.
model
.
parameterization
==
"v"
:
...
@@ -350,6 +353,7 @@ class KSampler:
...
@@ -350,6 +353,7 @@ class KSampler:
self
.
sigma_max
=
float
(
self
.
model_wrap
.
sigma_max
)
self
.
sigma_max
=
float
(
self
.
model_wrap
.
sigma_max
)
self
.
set_steps
(
steps
,
denoise
)
self
.
set_steps
(
steps
,
denoise
)
self
.
denoise
=
denoise
self
.
denoise
=
denoise
self
.
model_options
=
model_options
def
_calculate_sigmas
(
self
,
steps
):
def
_calculate_sigmas
(
self
,
steps
):
sigmas
=
None
sigmas
=
None
...
@@ -418,7 +422,7 @@ class KSampler:
...
@@ -418,7 +422,7 @@ class KSampler:
else
:
else
:
precision_scope
=
contextlib
.
nullcontext
precision_scope
=
contextlib
.
nullcontext
extra_args
=
{
"cond"
:
positive
,
"uncond"
:
negative
,
"cond_scale"
:
cfg
}
extra_args
=
{
"cond"
:
positive
,
"uncond"
:
negative
,
"cond_scale"
:
cfg
,
"model_options"
:
self
.
model_options
}
cond_concat
=
None
cond_concat
=
None
if
hasattr
(
self
.
model
,
'concat_keys'
):
if
hasattr
(
self
.
model
,
'concat_keys'
):
...
@@ -467,7 +471,7 @@ class KSampler:
...
@@ -467,7 +471,7 @@ class KSampler:
x_T
=
z_enc
,
x_T
=
z_enc
,
x0
=
latent_image
,
x0
=
latent_image
,
denoise_function
=
sampling_function
,
denoise_function
=
sampling_function
,
cond_concat
=
cond_concat
,
extra_args
=
extra_args
,
mask
=
noise_mask
,
mask
=
noise_mask
,
to_zero
=
sigmas
[
-
1
]
==
0
,
to_zero
=
sigmas
[
-
1
]
==
0
,
end_step
=
sigmas
.
shape
[
0
]
-
1
)
end_step
=
sigmas
.
shape
[
0
]
-
1
)
...
...
comfy/sd.py
View file @
38833ceb
import
torch
import
torch
import
contextlib
import
contextlib
import
copy
import
sd1_clip
import
sd1_clip
import
sd2_clip
import
sd2_clip
...
@@ -274,12 +275,20 @@ class ModelPatcher:
...
@@ -274,12 +275,20 @@ class ModelPatcher:
self
.
model
=
model
self
.
model
=
model
self
.
patches
=
[]
self
.
patches
=
[]
self
.
backup
=
{}
self
.
backup
=
{}
self
.
model_options
=
{
"transformer_options"
:{}}
def
clone
(
self
):
def
clone
(
self
):
n
=
ModelPatcher
(
self
.
model
)
n
=
ModelPatcher
(
self
.
model
)
n
.
patches
=
self
.
patches
[:]
n
.
patches
=
self
.
patches
[:]
n
.
model_options
=
copy
.
deepcopy
(
self
.
model_options
)
return
n
return
n
def
set_model_tomesd
(
self
,
ratio
):
self
.
model_options
[
"transformer_options"
][
"tomesd"
]
=
{
"ratio"
:
ratio
}
def
model_dtype
(
self
):
return
self
.
model
.
diffusion_model
.
dtype
def
add_patches
(
self
,
patches
,
strength
=
1.0
):
def
add_patches
(
self
,
patches
,
strength
=
1.0
):
p
=
{}
p
=
{}
model_sd
=
self
.
model
.
state_dict
()
model_sd
=
self
.
model
.
state_dict
()
...
...
nodes.py
View file @
38833ceb
...
@@ -254,6 +254,22 @@ class LoraLoader:
...
@@ -254,6 +254,22 @@ class LoraLoader:
model_lora
,
clip_lora
=
comfy
.
sd
.
load_lora_for_models
(
model
,
clip
,
lora_path
,
strength_model
,
strength_clip
)
model_lora
,
clip_lora
=
comfy
.
sd
.
load_lora_for_models
(
model
,
clip
,
lora_path
,
strength_model
,
strength_clip
)
return
(
model_lora
,
clip_lora
)
return
(
model_lora
,
clip_lora
)
class
TomePatchModel
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"model"
:
(
"MODEL"
,),
"ratio"
:
(
"FLOAT"
,
{
"default"
:
0.3
,
"min"
:
0.0
,
"max"
:
1.0
,
"step"
:
0.01
}),
}}
RETURN_TYPES
=
(
"MODEL"
,)
FUNCTION
=
"patch"
CATEGORY
=
"_for_testing"
def
patch
(
self
,
model
,
ratio
):
m
=
model
.
clone
()
m
.
set_model_tomesd
(
ratio
)
return
(
m
,
)
class
VAELoader
:
class
VAELoader
:
@
classmethod
@
classmethod
def
INPUT_TYPES
(
s
):
def
INPUT_TYPES
(
s
):
...
@@ -646,7 +662,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
...
@@ -646,7 +662,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
model_management
.
load_controlnet_gpu
(
control_net_models
)
model_management
.
load_controlnet_gpu
(
control_net_models
)
if
sampler_name
in
comfy
.
samplers
.
KSampler
.
SAMPLERS
:
if
sampler_name
in
comfy
.
samplers
.
KSampler
.
SAMPLERS
:
sampler
=
comfy
.
samplers
.
KSampler
(
real_model
,
steps
=
steps
,
device
=
device
,
sampler
=
sampler_name
,
scheduler
=
scheduler
,
denoise
=
denoise
)
sampler
=
comfy
.
samplers
.
KSampler
(
real_model
,
steps
=
steps
,
device
=
device
,
sampler
=
sampler_name
,
scheduler
=
scheduler
,
denoise
=
denoise
,
model_options
=
model
.
model_options
)
else
:
else
:
#other samplers
#other samplers
pass
pass
...
@@ -1016,6 +1032,7 @@ NODE_CLASS_MAPPINGS = {
...
@@ -1016,6 +1032,7 @@ NODE_CLASS_MAPPINGS = {
"CLIPVisionLoader"
:
CLIPVisionLoader
,
"CLIPVisionLoader"
:
CLIPVisionLoader
,
"VAEDecodeTiled"
:
VAEDecodeTiled
,
"VAEDecodeTiled"
:
VAEDecodeTiled
,
"VAEEncodeTiled"
:
VAEEncodeTiled
,
"VAEEncodeTiled"
:
VAEEncodeTiled
,
"TomePatchModel"
:
TomePatchModel
,
}
}
def
load_custom_node
(
module_path
):
def
load_custom_node
(
module_path
):
...
...
web/extensions/core/widgetInputs.js
View file @
38833ceb
...
@@ -101,7 +101,7 @@ app.registerExtension({
...
@@ -101,7 +101,7 @@ app.registerExtension({
callback
:
()
=>
convertToWidget
(
this
,
w
),
callback
:
()
=>
convertToWidget
(
this
,
w
),
});
});
}
else
{
}
else
{
const
config
=
nodeData
?.
input
?.
required
[
w
.
name
]
||
[
w
.
type
,
w
.
options
||
{}];
const
config
=
nodeData
?.
input
?.
required
[
w
.
name
]
||
nodeData
?.
input
?.
optional
?.[
w
.
name
]
||
[
w
.
type
,
w
.
options
||
{}];
if
(
isConvertableWidget
(
w
,
config
))
{
if
(
isConvertableWidget
(
w
,
config
))
{
toInput
.
push
({
toInput
.
push
({
content
:
`Convert
${
w
.
name
}
to input`
,
content
:
`Convert
${
w
.
name
}
to input`
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment