Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
e57cba4c
Commit
e57cba4c
authored
Jul 05, 2023
by
comfyanonymous
Browse files
Add gpu variations of the sde samplers that are less deterministic
but faster.
parent
f81b1929
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
8 deletions
+33
-8
comfy/k_diffusion/sampling.py
comfy/k_diffusion/sampling.py
+31
-6
comfy/samplers.py
comfy/samplers.py
+2
-2
No files found.
comfy/k_diffusion/sampling.py
View file @
e57cba4c
...
...
@@ -66,6 +66,9 @@ class BatchedBrownianTree:
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
def
__init__
(
self
,
x
,
t0
,
t1
,
seed
=
None
,
**
kwargs
):
self
.
cpu_tree
=
True
if
"cpu"
in
kwargs
:
self
.
cpu_tree
=
kwargs
.
pop
(
"cpu"
)
t0
,
t1
,
self
.
sign
=
self
.
sort
(
t0
,
t1
)
w0
=
kwargs
.
get
(
'w0'
,
torch
.
zeros_like
(
x
))
if
seed
is
None
:
...
...
@@ -77,7 +80,10 @@ class BatchedBrownianTree:
except
TypeError
:
seed
=
[
seed
]
self
.
batched
=
False
self
.
trees
=
[
torchsde
.
BrownianTree
(
t0
.
cpu
(),
w0
.
cpu
(),
t1
.
cpu
(),
entropy
=
s
,
**
kwargs
)
for
s
in
seed
]
if
self
.
cpu_tree
:
self
.
trees
=
[
torchsde
.
BrownianTree
(
t0
.
cpu
(),
w0
.
cpu
(),
t1
.
cpu
(),
entropy
=
s
,
**
kwargs
)
for
s
in
seed
]
else
:
self
.
trees
=
[
torchsde
.
BrownianTree
(
t0
,
w0
,
t1
,
entropy
=
s
,
**
kwargs
)
for
s
in
seed
]
@
staticmethod
def
sort
(
a
,
b
):
...
...
@@ -85,7 +91,11 @@ class BatchedBrownianTree:
def
__call__
(
self
,
t0
,
t1
):
t0
,
t1
,
sign
=
self
.
sort
(
t0
,
t1
)
w
=
torch
.
stack
([
tree
(
t0
.
cpu
().
float
(),
t1
.
cpu
().
float
()).
to
(
t0
.
dtype
).
to
(
t0
.
device
)
for
tree
in
self
.
trees
])
*
(
self
.
sign
*
sign
)
if
self
.
cpu_tree
:
w
=
torch
.
stack
([
tree
(
t0
.
cpu
().
float
(),
t1
.
cpu
().
float
()).
to
(
t0
.
dtype
).
to
(
t0
.
device
)
for
tree
in
self
.
trees
])
*
(
self
.
sign
*
sign
)
else
:
w
=
torch
.
stack
([
tree
(
t0
,
t1
)
for
tree
in
self
.
trees
])
*
(
self
.
sign
*
sign
)
return
w
if
self
.
batched
else
w
[
0
]
...
...
@@ -104,10 +114,10 @@ class BrownianTreeNoiseSampler:
internal timestep.
"""
def
__init__
(
self
,
x
,
sigma_min
,
sigma_max
,
seed
=
None
,
transform
=
lambda
x
:
x
):
def
__init__
(
self
,
x
,
sigma_min
,
sigma_max
,
seed
=
None
,
transform
=
lambda
x
:
x
,
cpu
=
False
):
self
.
transform
=
transform
t0
,
t1
=
self
.
transform
(
torch
.
as_tensor
(
sigma_min
)),
self
.
transform
(
torch
.
as_tensor
(
sigma_max
))
self
.
tree
=
BatchedBrownianTree
(
x
,
t0
,
t1
,
seed
)
self
.
tree
=
BatchedBrownianTree
(
x
,
t0
,
t1
,
seed
,
cpu
=
cpu
)
def
__call__
(
self
,
sigma
,
sigma_next
):
t0
,
t1
=
self
.
transform
(
torch
.
as_tensor
(
sigma
)),
self
.
transform
(
torch
.
as_tensor
(
sigma_next
))
...
...
@@ -544,7 +554,7 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
"""DPM-Solver++ (stochastic)."""
sigma_min
,
sigma_max
=
sigmas
[
sigmas
>
0
].
min
(),
sigmas
.
max
()
seed
=
extra_args
.
get
(
"seed"
,
None
)
noise_sampler
=
BrownianTreeNoiseSampler
(
x
,
sigma_min
,
sigma_max
,
seed
=
seed
)
if
noise_sampler
is
None
else
noise_sampler
noise_sampler
=
BrownianTreeNoiseSampler
(
x
,
sigma_min
,
sigma_max
,
seed
=
seed
,
cpu
=
True
)
if
noise_sampler
is
None
else
noise_sampler
extra_args
=
{}
if
extra_args
is
None
else
extra_args
s_in
=
x
.
new_ones
([
x
.
shape
[
0
]])
sigma_fn
=
lambda
t
:
t
.
neg
().
exp
()
...
...
@@ -616,7 +626,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
seed
=
extra_args
.
get
(
"seed"
,
None
)
sigma_min
,
sigma_max
=
sigmas
[
sigmas
>
0
].
min
(),
sigmas
.
max
()
noise_sampler
=
BrownianTreeNoiseSampler
(
x
,
sigma_min
,
sigma_max
,
seed
=
seed
)
if
noise_sampler
is
None
else
noise_sampler
noise_sampler
=
BrownianTreeNoiseSampler
(
x
,
sigma_min
,
sigma_max
,
seed
=
seed
,
cpu
=
True
)
if
noise_sampler
is
None
else
noise_sampler
extra_args
=
{}
if
extra_args
is
None
else
extra_args
s_in
=
x
.
new_ones
([
x
.
shape
[
0
]])
...
...
@@ -651,3 +661,18 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
old_denoised
=
denoised
h_last
=
h
return
x
@
torch
.
no_grad
()
def
sample_dpmpp_2m_sde_gpu
(
model
,
x
,
sigmas
,
extra_args
=
None
,
callback
=
None
,
disable
=
None
,
eta
=
1.
,
s_noise
=
1.
,
noise_sampler
=
None
,
solver_type
=
'midpoint'
):
sigma_min
,
sigma_max
=
sigmas
[
sigmas
>
0
].
min
(),
sigmas
.
max
()
noise_sampler
=
BrownianTreeNoiseSampler
(
x
,
sigma_min
,
sigma_max
,
seed
=
extra_args
.
get
(
"seed"
,
None
),
cpu
=
False
)
if
noise_sampler
is
None
else
noise_sampler
return
sample_dpmpp_2m_sde
(
model
,
x
,
sigmas
,
extra_args
=
extra_args
,
callback
=
callback
,
disable
=
disable
,
eta
=
eta
,
s_noise
=
s_noise
,
noise_sampler
=
noise_sampler
,
solver_type
=
solver_type
)
@
torch
.
no_grad
()
def
sample_dpmpp_sde_gpu
(
model
,
x
,
sigmas
,
extra_args
=
None
,
callback
=
None
,
disable
=
None
,
eta
=
1.
,
s_noise
=
1.
,
noise_sampler
=
None
,
r
=
1
/
2
):
sigma_min
,
sigma_max
=
sigmas
[
sigmas
>
0
].
min
(),
sigmas
.
max
()
noise_sampler
=
BrownianTreeNoiseSampler
(
x
,
sigma_min
,
sigma_max
,
seed
=
extra_args
.
get
(
"seed"
,
None
),
cpu
=
False
)
if
noise_sampler
is
None
else
noise_sampler
return
sample_dpmpp_sde
(
model
,
x
,
sigmas
,
extra_args
=
extra_args
,
callback
=
callback
,
disable
=
disable
,
eta
=
eta
,
s_noise
=
s_noise
,
noise_sampler
=
noise_sampler
,
r
=
r
)
comfy/samplers.py
View file @
e57cba4c
...
...
@@ -483,8 +483,8 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
class
KSampler
:
SCHEDULERS
=
[
"normal"
,
"karras"
,
"exponential"
,
"simple"
,
"ddim_uniform"
]
SAMPLERS
=
[
"euler"
,
"euler_ancestral"
,
"heun"
,
"dpm_2"
,
"dpm_2_ancestral"
,
"lms"
,
"dpm_fast"
,
"dpm_adaptive"
,
"dpmpp_2s_ancestral"
,
"dpmpp_sde"
,
"dpmpp_2m"
,
"dpmpp_2m_sde"
,
"ddim"
,
"uni_pc"
,
"uni_pc_bh2"
]
"lms"
,
"dpm_fast"
,
"dpm_adaptive"
,
"dpmpp_2s_ancestral"
,
"dpmpp_sde"
,
"dpmpp_sde_gpu"
,
"dpmpp_2m"
,
"dpmpp_2m_sde"
,
"dpmpp_2m_sde_gpu"
,
"ddim"
,
"uni_pc"
,
"uni_pc_bh2"
]
def
__init__
(
self
,
model
,
steps
,
device
,
sampler
=
None
,
scheduler
=
None
,
denoise
=
None
,
model_options
=
{}):
self
.
model
=
model
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment