Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
4eab00e1
Commit
4eab00e1
authored
Jun 25, 2023
by
comfyanonymous
Browse files
Set the seed in the SDE samplers to make them more reproducible.
parent
cef6aa62
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
16 additions
and
14 deletions
+16
-14
comfy/k_diffusion/sampling.py
comfy/k_diffusion/sampling.py
+6
-4
comfy/sample.py
comfy/sample.py
+2
-2
comfy/samplers.py
comfy/samplers.py
+7
-7
nodes.py
nodes.py
+1
-1
No files found.
comfy/k_diffusion/sampling.py
View file @
4eab00e1
...
@@ -77,7 +77,7 @@ class BatchedBrownianTree:
...
@@ -77,7 +77,7 @@ class BatchedBrownianTree:
except
TypeError
:
except
TypeError
:
seed
=
[
seed
]
seed
=
[
seed
]
self
.
batched
=
False
self
.
batched
=
False
self
.
trees
=
[
torchsde
.
BrownianTree
(
t0
,
w0
,
t1
,
entropy
=
s
,
**
kwargs
)
for
s
in
seed
]
self
.
trees
=
[
torchsde
.
BrownianTree
(
t0
.
cpu
(),
w0
.
cpu
(),
t1
.
cpu
()
,
entropy
=
s
,
**
kwargs
)
for
s
in
seed
]
@
staticmethod
@
staticmethod
def
sort
(
a
,
b
):
def
sort
(
a
,
b
):
...
@@ -85,7 +85,7 @@ class BatchedBrownianTree:
...
@@ -85,7 +85,7 @@ class BatchedBrownianTree:
def
__call__
(
self
,
t0
,
t1
):
def
__call__
(
self
,
t0
,
t1
):
t0
,
t1
,
sign
=
self
.
sort
(
t0
,
t1
)
t0
,
t1
,
sign
=
self
.
sort
(
t0
,
t1
)
w
=
torch
.
stack
([
tree
(
t0
,
t1
)
for
tree
in
self
.
trees
])
*
(
self
.
sign
*
sign
)
w
=
torch
.
stack
([
tree
(
t0
.
cpu
().
float
(),
t1
.
cpu
().
float
()).
to
(
t0
.
dtype
).
to
(
t0
.
device
)
for
tree
in
self
.
trees
])
*
(
self
.
sign
*
sign
)
return
w
if
self
.
batched
else
w
[
0
]
return
w
if
self
.
batched
else
w
[
0
]
...
@@ -543,7 +543,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
...
@@ -543,7 +543,8 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
def
sample_dpmpp_sde
(
model
,
x
,
sigmas
,
extra_args
=
None
,
callback
=
None
,
disable
=
None
,
eta
=
1.
,
s_noise
=
1.
,
noise_sampler
=
None
,
r
=
1
/
2
):
def
sample_dpmpp_sde
(
model
,
x
,
sigmas
,
extra_args
=
None
,
callback
=
None
,
disable
=
None
,
eta
=
1.
,
s_noise
=
1.
,
noise_sampler
=
None
,
r
=
1
/
2
):
"""DPM-Solver++ (stochastic)."""
"""DPM-Solver++ (stochastic)."""
sigma_min
,
sigma_max
=
sigmas
[
sigmas
>
0
].
min
(),
sigmas
.
max
()
sigma_min
,
sigma_max
=
sigmas
[
sigmas
>
0
].
min
(),
sigmas
.
max
()
noise_sampler
=
BrownianTreeNoiseSampler
(
x
,
sigma_min
,
sigma_max
)
if
noise_sampler
is
None
else
noise_sampler
seed
=
extra_args
.
get
(
"seed"
,
None
)
noise_sampler
=
BrownianTreeNoiseSampler
(
x
,
sigma_min
,
sigma_max
,
seed
=
seed
)
if
noise_sampler
is
None
else
noise_sampler
extra_args
=
{}
if
extra_args
is
None
else
extra_args
extra_args
=
{}
if
extra_args
is
None
else
extra_args
s_in
=
x
.
new_ones
([
x
.
shape
[
0
]])
s_in
=
x
.
new_ones
([
x
.
shape
[
0
]])
sigma_fn
=
lambda
t
:
t
.
neg
().
exp
()
sigma_fn
=
lambda
t
:
t
.
neg
().
exp
()
...
@@ -613,8 +614,9 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
...
@@ -613,8 +614,9 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
if
solver_type
not
in
{
'heun'
,
'midpoint'
}:
if
solver_type
not
in
{
'heun'
,
'midpoint'
}:
raise
ValueError
(
'solver_type must be
\'
heun
\'
or
\'
midpoint
\'
'
)
raise
ValueError
(
'solver_type must be
\'
heun
\'
or
\'
midpoint
\'
'
)
seed
=
extra_args
.
get
(
"seed"
,
None
)
sigma_min
,
sigma_max
=
sigmas
[
sigmas
>
0
].
min
(),
sigmas
.
max
()
sigma_min
,
sigma_max
=
sigmas
[
sigmas
>
0
].
min
(),
sigmas
.
max
()
noise_sampler
=
BrownianTreeNoiseSampler
(
x
,
sigma_min
,
sigma_max
)
if
noise_sampler
is
None
else
noise_sampler
noise_sampler
=
BrownianTreeNoiseSampler
(
x
,
sigma_min
,
sigma_max
,
seed
=
seed
)
if
noise_sampler
is
None
else
noise_sampler
extra_args
=
{}
if
extra_args
is
None
else
extra_args
extra_args
=
{}
if
extra_args
is
None
else
extra_args
s_in
=
x
.
new_ones
([
x
.
shape
[
0
]])
s_in
=
x
.
new_ones
([
x
.
shape
[
0
]])
...
...
comfy/sample.py
View file @
4eab00e1
...
@@ -65,7 +65,7 @@ def cleanup_additional_models(models):
...
@@ -65,7 +65,7 @@ def cleanup_additional_models(models):
for
m
in
models
:
for
m
in
models
:
m
.
cleanup
()
m
.
cleanup
()
def
sample
(
model
,
noise
,
steps
,
cfg
,
sampler_name
,
scheduler
,
positive
,
negative
,
latent_image
,
denoise
=
1.0
,
disable_noise
=
False
,
start_step
=
None
,
last_step
=
None
,
force_full_denoise
=
False
,
noise_mask
=
None
,
sigmas
=
None
,
callback
=
None
,
disable_pbar
=
False
):
def
sample
(
model
,
noise
,
steps
,
cfg
,
sampler_name
,
scheduler
,
positive
,
negative
,
latent_image
,
denoise
=
1.0
,
disable_noise
=
False
,
start_step
=
None
,
last_step
=
None
,
force_full_denoise
=
False
,
noise_mask
=
None
,
sigmas
=
None
,
callback
=
None
,
disable_pbar
=
False
,
seed
=
None
):
device
=
comfy
.
model_management
.
get_torch_device
()
device
=
comfy
.
model_management
.
get_torch_device
()
if
noise_mask
is
not
None
:
if
noise_mask
is
not
None
:
...
@@ -85,7 +85,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
...
@@ -85,7 +85,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
sampler
=
comfy
.
samplers
.
KSampler
(
real_model
,
steps
=
steps
,
device
=
device
,
sampler
=
sampler_name
,
scheduler
=
scheduler
,
denoise
=
denoise
,
model_options
=
model
.
model_options
)
sampler
=
comfy
.
samplers
.
KSampler
(
real_model
,
steps
=
steps
,
device
=
device
,
sampler
=
sampler_name
,
scheduler
=
scheduler
,
denoise
=
denoise
,
model_options
=
model
.
model_options
)
samples
=
sampler
.
sample
(
noise
,
positive_copy
,
negative_copy
,
cfg
=
cfg
,
latent_image
=
latent_image
,
start_step
=
start_step
,
last_step
=
last_step
,
force_full_denoise
=
force_full_denoise
,
denoise_mask
=
noise_mask
,
sigmas
=
sigmas
,
callback
=
callback
,
disable_pbar
=
disable_pbar
)
samples
=
sampler
.
sample
(
noise
,
positive_copy
,
negative_copy
,
cfg
=
cfg
,
latent_image
=
latent_image
,
start_step
=
start_step
,
last_step
=
last_step
,
force_full_denoise
=
force_full_denoise
,
denoise_mask
=
noise_mask
,
sigmas
=
sigmas
,
callback
=
callback
,
disable_pbar
=
disable_pbar
,
seed
=
seed
)
samples
=
samples
.
cpu
()
samples
=
samples
.
cpu
()
cleanup_additional_models
(
models
)
cleanup_additional_models
(
models
)
...
...
comfy/samplers.py
View file @
4eab00e1
...
@@ -13,7 +13,7 @@ def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
...
@@ -13,7 +13,7 @@ def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
#The main sampling function shared by all the samplers
#The main sampling function shared by all the samplers
#Returns predicted noise
#Returns predicted noise
def
sampling_function
(
model_function
,
x
,
timestep
,
uncond
,
cond
,
cond_scale
,
cond_concat
=
None
,
model_options
=
{}):
def
sampling_function
(
model_function
,
x
,
timestep
,
uncond
,
cond
,
cond_scale
,
cond_concat
=
None
,
model_options
=
{}
,
seed
=
None
):
def
get_area_and_mult
(
cond
,
x_in
,
cond_concat_in
,
timestep_in
):
def
get_area_and_mult
(
cond
,
x_in
,
cond_concat_in
,
timestep_in
):
area
=
(
x_in
.
shape
[
2
],
x_in
.
shape
[
3
],
0
,
0
)
area
=
(
x_in
.
shape
[
2
],
x_in
.
shape
[
3
],
0
,
0
)
strength
=
1.0
strength
=
1.0
...
@@ -292,8 +292,8 @@ class CFGNoisePredictor(torch.nn.Module):
...
@@ -292,8 +292,8 @@ class CFGNoisePredictor(torch.nn.Module):
super
().
__init__
()
super
().
__init__
()
self
.
inner_model
=
model
self
.
inner_model
=
model
self
.
alphas_cumprod
=
model
.
alphas_cumprod
self
.
alphas_cumprod
=
model
.
alphas_cumprod
def
apply_model
(
self
,
x
,
timestep
,
cond
,
uncond
,
cond_scale
,
cond_concat
=
None
,
model_options
=
{}):
def
apply_model
(
self
,
x
,
timestep
,
cond
,
uncond
,
cond_scale
,
cond_concat
=
None
,
model_options
=
{}
,
seed
=
None
):
out
=
sampling_function
(
self
.
inner_model
.
apply_model
,
x
,
timestep
,
uncond
,
cond
,
cond_scale
,
cond_concat
,
model_options
=
model_options
)
out
=
sampling_function
(
self
.
inner_model
.
apply_model
,
x
,
timestep
,
uncond
,
cond
,
cond_scale
,
cond_concat
,
model_options
=
model_options
,
seed
=
seed
)
return
out
return
out
...
@@ -301,11 +301,11 @@ class KSamplerX0Inpaint(torch.nn.Module):
...
@@ -301,11 +301,11 @@ class KSamplerX0Inpaint(torch.nn.Module):
def
__init__
(
self
,
model
):
def
__init__
(
self
,
model
):
super
().
__init__
()
super
().
__init__
()
self
.
inner_model
=
model
self
.
inner_model
=
model
def
forward
(
self
,
x
,
sigma
,
uncond
,
cond
,
cond_scale
,
denoise_mask
,
cond_concat
=
None
,
model_options
=
{}):
def
forward
(
self
,
x
,
sigma
,
uncond
,
cond
,
cond_scale
,
denoise_mask
,
cond_concat
=
None
,
model_options
=
{}
,
seed
=
None
):
if
denoise_mask
is
not
None
:
if
denoise_mask
is
not
None
:
latent_mask
=
1.
-
denoise_mask
latent_mask
=
1.
-
denoise_mask
x
=
x
*
denoise_mask
+
(
self
.
latent_image
+
self
.
noise
*
sigma
.
reshape
([
sigma
.
shape
[
0
]]
+
[
1
]
*
(
len
(
self
.
noise
.
shape
)
-
1
)))
*
latent_mask
x
=
x
*
denoise_mask
+
(
self
.
latent_image
+
self
.
noise
*
sigma
.
reshape
([
sigma
.
shape
[
0
]]
+
[
1
]
*
(
len
(
self
.
noise
.
shape
)
-
1
)))
*
latent_mask
out
=
self
.
inner_model
(
x
,
sigma
,
cond
=
cond
,
uncond
=
uncond
,
cond_scale
=
cond_scale
,
cond_concat
=
cond_concat
,
model_options
=
model_options
)
out
=
self
.
inner_model
(
x
,
sigma
,
cond
=
cond
,
uncond
=
uncond
,
cond_scale
=
cond_scale
,
cond_concat
=
cond_concat
,
model_options
=
model_options
,
seed
=
seed
)
if
denoise_mask
is
not
None
:
if
denoise_mask
is
not
None
:
out
*=
denoise_mask
out
*=
denoise_mask
...
@@ -542,7 +542,7 @@ class KSampler:
...
@@ -542,7 +542,7 @@ class KSampler:
sigmas
=
self
.
calculate_sigmas
(
new_steps
).
to
(
self
.
device
)
sigmas
=
self
.
calculate_sigmas
(
new_steps
).
to
(
self
.
device
)
self
.
sigmas
=
sigmas
[
-
(
steps
+
1
):]
self
.
sigmas
=
sigmas
[
-
(
steps
+
1
):]
def
sample
(
self
,
noise
,
positive
,
negative
,
cfg
,
latent_image
=
None
,
start_step
=
None
,
last_step
=
None
,
force_full_denoise
=
False
,
denoise_mask
=
None
,
sigmas
=
None
,
callback
=
None
,
disable_pbar
=
False
):
def
sample
(
self
,
noise
,
positive
,
negative
,
cfg
,
latent_image
=
None
,
start_step
=
None
,
last_step
=
None
,
force_full_denoise
=
False
,
denoise_mask
=
None
,
sigmas
=
None
,
callback
=
None
,
disable_pbar
=
False
,
seed
=
None
):
if
sigmas
is
None
:
if
sigmas
is
None
:
sigmas
=
self
.
sigmas
sigmas
=
self
.
sigmas
sigma_min
=
self
.
sigma_min
sigma_min
=
self
.
sigma_min
...
@@ -589,7 +589,7 @@ class KSampler:
...
@@ -589,7 +589,7 @@ class KSampler:
if
latent_image
is
not
None
:
if
latent_image
is
not
None
:
latent_image
=
self
.
model
.
process_latent_in
(
latent_image
)
latent_image
=
self
.
model
.
process_latent_in
(
latent_image
)
extra_args
=
{
"cond"
:
positive
,
"uncond"
:
negative
,
"cond_scale"
:
cfg
,
"model_options"
:
self
.
model_options
}
extra_args
=
{
"cond"
:
positive
,
"uncond"
:
negative
,
"cond_scale"
:
cfg
,
"model_options"
:
self
.
model_options
,
"seed"
:
seed
}
cond_concat
=
None
cond_concat
=
None
if
hasattr
(
self
.
model
,
'concat_keys'
):
#inpaint
if
hasattr
(
self
.
model
,
'concat_keys'
):
#inpaint
...
...
nodes.py
View file @
4eab00e1
...
@@ -965,7 +965,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
...
@@ -965,7 +965,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
samples
=
comfy
.
sample
.
sample
(
model
,
noise
,
steps
,
cfg
,
sampler_name
,
scheduler
,
positive
,
negative
,
latent_image
,
samples
=
comfy
.
sample
.
sample
(
model
,
noise
,
steps
,
cfg
,
sampler_name
,
scheduler
,
positive
,
negative
,
latent_image
,
denoise
=
denoise
,
disable_noise
=
disable_noise
,
start_step
=
start_step
,
last_step
=
last_step
,
denoise
=
denoise
,
disable_noise
=
disable_noise
,
start_step
=
start_step
,
last_step
=
last_step
,
force_full_denoise
=
force_full_denoise
,
noise_mask
=
noise_mask
,
callback
=
callback
)
force_full_denoise
=
force_full_denoise
,
noise_mask
=
noise_mask
,
callback
=
callback
,
seed
=
seed
)
out
=
latent
.
copy
()
out
=
latent
.
copy
()
out
[
"samples"
]
=
samples
out
[
"samples"
]
=
samples
return
(
out
,
)
return
(
out
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment