Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
12c1080e
"lightx2v_platform/vscode:/vscode.git/clone" did not exist on "5546f759a477bd34310591102037c8cd0afd5faa"
Commit
12c1080e
authored
Mar 03, 2024
by
comfyanonymous
Browse files
Simplify differential diffusion code.
parent
727021bd
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
23 additions
and
74 deletions
+23
-74
comfy/model_patcher.py
comfy/model_patcher.py
+3
-0
comfy/samplers.py
comfy/samplers.py
+4
-3
comfy_extras/nodes_differential_diffusion.py
comfy_extras/nodes_differential_diffusion.py
+16
-71
No files found.
comfy/model_patcher.py
View file @
12c1080e
...
...
@@ -67,6 +67,9 @@ class ModelPatcher:
def
set_model_unet_function_wrapper
(
self
,
unet_wrapper_function
):
self
.
model_options
[
"model_function_wrapper"
]
=
unet_wrapper_function
def
set_model_denoise_mask_function
(
self
,
denoise_mask_function
):
self
.
model_options
[
"denoise_mask_function"
]
=
denoise_mask_function
def
set_model_patch
(
self
,
patch
,
name
):
to
=
self
.
model_options
[
"transformer_options"
]
if
"patches"
not
in
to
:
...
...
comfy/samplers.py
View file @
12c1080e
...
...
@@ -272,13 +272,14 @@ class CFGNoisePredictor(torch.nn.Module):
return
self
.
apply_model
(
*
args
,
**
kwargs
)
class
KSamplerX0Inpaint
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
model
):
def
__init__
(
self
,
model
,
sigmas
):
super
().
__init__
()
self
.
inner_model
=
model
self
.
sigmas
=
sigmas
def
forward
(
self
,
x
,
sigma
,
uncond
,
cond
,
cond_scale
,
denoise_mask
,
model_options
=
{},
seed
=
None
):
if
denoise_mask
is
not
None
:
if
"denoise_mask_function"
in
model_options
:
denoise_mask
=
model_options
[
"denoise_mask_function"
](
sigma
,
denoise_mask
)
denoise_mask
=
model_options
[
"denoise_mask_function"
](
sigma
,
denoise_mask
,
extra_options
=
{
"model"
:
self
.
inner_model
,
"sigmas"
:
self
.
sigmas
}
)
latent_mask
=
1.
-
denoise_mask
x
=
x
*
denoise_mask
+
self
.
inner_model
.
inner_model
.
model_sampling
.
noise_scaling
(
sigma
.
reshape
([
sigma
.
shape
[
0
]]
+
[
1
]
*
(
len
(
self
.
noise
.
shape
)
-
1
)),
self
.
noise
,
self
.
latent_image
)
*
latent_mask
out
=
self
.
inner_model
(
x
,
sigma
,
cond
=
cond
,
uncond
=
uncond
,
cond_scale
=
cond_scale
,
model_options
=
model_options
,
seed
=
seed
)
...
...
@@ -528,7 +529,7 @@ class KSAMPLER(Sampler):
def
sample
(
self
,
model_wrap
,
sigmas
,
extra_args
,
callback
,
noise
,
latent_image
=
None
,
denoise_mask
=
None
,
disable_pbar
=
False
):
extra_args
[
"denoise_mask"
]
=
denoise_mask
model_k
=
KSamplerX0Inpaint
(
model_wrap
)
model_k
=
KSamplerX0Inpaint
(
model_wrap
,
sigmas
)
model_k
.
latent_image
=
latent_image
if
self
.
inpaint_options
.
get
(
"random"
,
False
):
#TODO: Should this be the default?
generator
=
torch
.
manual_seed
(
extra_args
.
get
(
"seed"
,
41
)
+
1
)
...
...
comfy_extras/nodes_differential_diffusion.py
View file @
12c1080e
# code adapted from https://github.com/exx8/differential-diffusion
import
torch
import
inspect
class
DifferentialDiffusion
():
@
classmethod
...
...
@@ -13,80 +12,26 @@ class DifferentialDiffusion():
CATEGORY
=
"_for_testing"
INIT
=
False
@
classmethod
def
IS_CHANGED
(
s
,
*
args
,
**
kwargs
):
DifferentialDiffusion
.
INIT
=
s
.
INIT
=
True
return
""
def
__init__
(
self
)
->
None
:
DifferentialDiffusion
.
INIT
=
False
self
.
sigmas
:
torch
.
Tensor
=
None
self
.
thresholds
:
torch
.
Tensor
=
None
self
.
mask_i
=
None
self
.
valid_sigmas
=
False
self
.
varying_sigmas_samplers
=
[
"dpmpp_2s"
,
"dpmpp_sde"
,
"dpm_2"
,
"heun"
,
"restart"
]
def
apply
(
self
,
model
):
model
=
model
.
clone
()
model
.
model_
options
[
"
denoise_mask_function
"
]
=
self
.
forward
model
.
set_
model_denoise_mask_function
(
self
.
forward
)
return
(
model
,)
def
init_sigmas
(
self
,
sigma
:
torch
.
Tensor
,
denoise_mask
:
torch
.
Tensor
):
self
.
__init__
()
self
.
sigmas
,
sampler
=
find_outer_instance
(
"sigmas"
,
callback
=
get_sigmas_and_sampler
)
or
(
None
,
""
)
self
.
valid_sigmas
=
not
(
"sample_"
not
in
sampler
or
any
(
s
in
sampler
for
s
in
self
.
varying_sigmas_samplers
))
or
"generic"
in
sampler
if
self
.
sigmas
is
None
:
self
.
sigmas
=
sigma
[:
1
].
repeat
(
2
)
self
.
sigmas
[
-
1
].
zero_
()
self
.
sigmas_min
=
self
.
sigmas
.
min
()
self
.
sigmas_max
=
self
.
sigmas
.
max
()
self
.
thresholds
=
torch
.
linspace
(
1
,
0
,
self
.
sigmas
.
shape
[
0
],
dtype
=
sigma
.
dtype
,
device
=
sigma
.
device
)
self
.
thresholds_min_len
=
self
.
thresholds
.
shape
[
0
]
-
1
if
self
.
valid_sigmas
:
thresholds
=
self
.
thresholds
[:
-
1
].
reshape
(
-
1
,
1
,
1
,
1
,
1
)
mask
=
denoise_mask
.
unsqueeze
(
0
)
mask
=
(
mask
>=
thresholds
).
to
(
denoise_mask
.
dtype
)
self
.
mask_i
=
iter
(
mask
)
def
forward
(
self
,
sigma
:
torch
.
Tensor
,
denoise_mask
:
torch
.
Tensor
,
extra_options
:
dict
):
model
=
extra_options
[
"model"
]
step_sigmas
=
extra_options
[
"sigmas"
]
sigma_to
=
model
.
inner_model
.
model_sampling
.
sigma_min
if
step_sigmas
[
-
1
]
>
sigma_to
:
sigma_to
=
step_sigmas
[
-
1
]
sigma_from
=
step_sigmas
[
0
]
def
forward
(
self
,
sigma
:
torch
.
Tensor
,
denoise_mask
:
torch
.
Tensor
):
if
self
.
sigmas
is
None
or
DifferentialDiffusion
.
INIT
:
self
.
init_sigmas
(
sigma
,
denoise_mask
)
if
self
.
valid_sigmas
:
try
:
return
next
(
self
.
mask_i
)
except
StopIteration
:
self
.
valid_sigmas
=
False
if
self
.
thresholds_min_len
>
1
:
nearest_idx
=
(
self
.
sigmas
-
sigma
[
0
]).
abs
().
argmin
()
if
not
self
.
thresholds_min_len
>
nearest_idx
:
nearest_idx
=
-
2
threshold
=
self
.
thresholds
[
nearest_idx
]
else
:
threshold
=
(
sigma
[
0
]
-
self
.
sigmas_min
)
/
(
self
.
sigmas_max
-
self
.
sigmas_min
)
return
(
denoise_mask
>=
threshold
).
to
(
denoise_mask
.
dtype
)
ts_from
=
model
.
inner_model
.
model_sampling
.
timestep
(
sigma_from
)
ts_to
=
model
.
inner_model
.
model_sampling
.
timestep
(
sigma_to
)
current_ts
=
model
.
inner_model
.
model_sampling
.
timestep
(
sigma
)
def
get_sigmas_and_sampler
(
frame
,
target
):
found
=
frame
.
f_locals
[
target
]
if
isinstance
(
found
,
torch
.
Tensor
)
and
found
[
-
1
]
<
0.1
:
return
found
,
frame
.
f_code
.
co_name
return
False
threshold
=
(
current_ts
-
ts_to
)
/
(
ts_from
-
ts_to
)
def
find_outer_instance
(
target
:
str
,
target_type
=
None
,
callback
=
None
):
frame
=
inspect
.
currentframe
()
i
=
0
while
frame
and
i
<
100
:
if
target
in
frame
.
f_locals
:
if
callback
is
not
None
:
res
=
callback
(
frame
,
target
)
if
res
:
return
res
else
:
found
=
frame
.
f_locals
[
target
]
if
isinstance
(
found
,
target_type
):
return
found
frame
=
frame
.
f_back
i
+=
1
return
None
return
(
denoise_mask
>=
threshold
).
to
(
denoise_mask
.
dtype
)
NODE_CLASS_MAPPINGS
=
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment