Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
e57cba4c
"...composable_kernel_onnx.git" did not exist on "40b59a63cc6308c01390e6ab07015a2f34a7b16a"
Commit
e57cba4c
authored
Jul 05, 2023
by
comfyanonymous
Browse files
Add gpu variations of the sde samplers that are less deterministic
but faster.
parent
f81b1929
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
8 deletions
+33
-8
comfy/k_diffusion/sampling.py
comfy/k_diffusion/sampling.py
+31
-6
comfy/samplers.py
comfy/samplers.py
+2
-2
No files found.
comfy/k_diffusion/sampling.py
View file @
e57cba4c
...
@@ -66,6 +66,9 @@ class BatchedBrownianTree:
...
@@ -66,6 +66,9 @@ class BatchedBrownianTree:
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
def
__init__
(
self
,
x
,
t0
,
t1
,
seed
=
None
,
**
kwargs
):
def
__init__
(
self
,
x
,
t0
,
t1
,
seed
=
None
,
**
kwargs
):
self
.
cpu_tree
=
True
if
"cpu"
in
kwargs
:
self
.
cpu_tree
=
kwargs
.
pop
(
"cpu"
)
t0
,
t1
,
self
.
sign
=
self
.
sort
(
t0
,
t1
)
t0
,
t1
,
self
.
sign
=
self
.
sort
(
t0
,
t1
)
w0
=
kwargs
.
get
(
'w0'
,
torch
.
zeros_like
(
x
))
w0
=
kwargs
.
get
(
'w0'
,
torch
.
zeros_like
(
x
))
if
seed
is
None
:
if
seed
is
None
:
...
@@ -77,7 +80,10 @@ class BatchedBrownianTree:
...
@@ -77,7 +80,10 @@ class BatchedBrownianTree:
except
TypeError
:
except
TypeError
:
seed
=
[
seed
]
seed
=
[
seed
]
self
.
batched
=
False
self
.
batched
=
False
self
.
trees
=
[
torchsde
.
BrownianTree
(
t0
.
cpu
(),
w0
.
cpu
(),
t1
.
cpu
(),
entropy
=
s
,
**
kwargs
)
for
s
in
seed
]
if
self
.
cpu_tree
:
self
.
trees
=
[
torchsde
.
BrownianTree
(
t0
.
cpu
(),
w0
.
cpu
(),
t1
.
cpu
(),
entropy
=
s
,
**
kwargs
)
for
s
in
seed
]
else
:
self
.
trees
=
[
torchsde
.
BrownianTree
(
t0
,
w0
,
t1
,
entropy
=
s
,
**
kwargs
)
for
s
in
seed
]
@
staticmethod
@
staticmethod
def
sort
(
a
,
b
):
def
sort
(
a
,
b
):
...
@@ -85,7 +91,11 @@ class BatchedBrownianTree:
...
@@ -85,7 +91,11 @@ class BatchedBrownianTree:
def
__call__
(
self
,
t0
,
t1
):
def
__call__
(
self
,
t0
,
t1
):
t0
,
t1
,
sign
=
self
.
sort
(
t0
,
t1
)
t0
,
t1
,
sign
=
self
.
sort
(
t0
,
t1
)
w
=
torch
.
stack
([
tree
(
t0
.
cpu
().
float
(),
t1
.
cpu
().
float
()).
to
(
t0
.
dtype
).
to
(
t0
.
device
)
for
tree
in
self
.
trees
])
*
(
self
.
sign
*
sign
)
if
self
.
cpu_tree
:
w
=
torch
.
stack
([
tree
(
t0
.
cpu
().
float
(),
t1
.
cpu
().
float
()).
to
(
t0
.
dtype
).
to
(
t0
.
device
)
for
tree
in
self
.
trees
])
*
(
self
.
sign
*
sign
)
else
:
w
=
torch
.
stack
([
tree
(
t0
,
t1
)
for
tree
in
self
.
trees
])
*
(
self
.
sign
*
sign
)
return
w
if
self
.
batched
else
w
[
0
]
return
w
if
self
.
batched
else
w
[
0
]
...
@@ -104,10 +114,10 @@ class BrownianTreeNoiseSampler:
...
@@ -104,10 +114,10 @@ class BrownianTreeNoiseSampler:
internal timestep.
internal timestep.
"""
"""
def
__init__
(
self
,
x
,
sigma_min
,
sigma_max
,
seed
=
None
,
transform
=
lambda
x
:
x
):
def
__init__
(
self
,
x
,
sigma_min
,
sigma_max
,
seed
=
None
,
transform
=
lambda
x
:
x
,
cpu
=
False
):
self
.
transform
=
transform
self
.
transform
=
transform
t0
,
t1
=
self
.
transform
(
torch
.
as_tensor
(
sigma_min
)),
self
.
transform
(
torch
.
as_tensor
(
sigma_max
))
t0
,
t1
=
self
.
transform
(
torch
.
as_tensor
(
sigma_min
)),
self
.
transform
(
torch
.
as_tensor
(
sigma_max
))
self
.
tree
=
BatchedBrownianTree
(
x
,
t0
,
t1
,
seed
)
self
.
tree
=
BatchedBrownianTree
(
x
,
t0
,
t1
,
seed
,
cpu
=
cpu
)
def
__call__
(
self
,
sigma
,
sigma_next
):
def
__call__
(
self
,
sigma
,
sigma_next
):
t0
,
t1
=
self
.
transform
(
torch
.
as_tensor
(
sigma
)),
self
.
transform
(
torch
.
as_tensor
(
sigma_next
))
t0
,
t1
=
self
.
transform
(
torch
.
as_tensor
(
sigma
)),
self
.
transform
(
torch
.
as_tensor
(
sigma_next
))
...
@@ -544,7 +554,7 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
...
@@ -544,7 +554,7 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
"""DPM-Solver++ (stochastic)."""
"""DPM-Solver++ (stochastic)."""
sigma_min
,
sigma_max
=
sigmas
[
sigmas
>
0
].
min
(),
sigmas
.
max
()
sigma_min
,
sigma_max
=
sigmas
[
sigmas
>
0
].
min
(),
sigmas
.
max
()
seed
=
extra_args
.
get
(
"seed"
,
None
)
seed
=
extra_args
.
get
(
"seed"
,
None
)
noise_sampler
=
BrownianTreeNoiseSampler
(
x
,
sigma_min
,
sigma_max
,
seed
=
seed
)
if
noise_sampler
is
None
else
noise_sampler
noise_sampler
=
BrownianTreeNoiseSampler
(
x
,
sigma_min
,
sigma_max
,
seed
=
seed
,
cpu
=
True
)
if
noise_sampler
is
None
else
noise_sampler
extra_args
=
{}
if
extra_args
is
None
else
extra_args
extra_args
=
{}
if
extra_args
is
None
else
extra_args
s_in
=
x
.
new_ones
([
x
.
shape
[
0
]])
s_in
=
x
.
new_ones
([
x
.
shape
[
0
]])
sigma_fn
=
lambda
t
:
t
.
neg
().
exp
()
sigma_fn
=
lambda
t
:
t
.
neg
().
exp
()
...
@@ -616,7 +626,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
...
@@ -616,7 +626,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
seed
=
extra_args
.
get
(
"seed"
,
None
)
seed
=
extra_args
.
get
(
"seed"
,
None
)
sigma_min
,
sigma_max
=
sigmas
[
sigmas
>
0
].
min
(),
sigmas
.
max
()
sigma_min
,
sigma_max
=
sigmas
[
sigmas
>
0
].
min
(),
sigmas
.
max
()
noise_sampler
=
BrownianTreeNoiseSampler
(
x
,
sigma_min
,
sigma_max
,
seed
=
seed
)
if
noise_sampler
is
None
else
noise_sampler
noise_sampler
=
BrownianTreeNoiseSampler
(
x
,
sigma_min
,
sigma_max
,
seed
=
seed
,
cpu
=
True
)
if
noise_sampler
is
None
else
noise_sampler
extra_args
=
{}
if
extra_args
is
None
else
extra_args
extra_args
=
{}
if
extra_args
is
None
else
extra_args
s_in
=
x
.
new_ones
([
x
.
shape
[
0
]])
s_in
=
x
.
new_ones
([
x
.
shape
[
0
]])
...
@@ -651,3 +661,18 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
...
@@ -651,3 +661,18 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
old_denoised
=
denoised
old_denoised
=
denoised
h_last
=
h
h_last
=
h
return
x
return
x
@
torch
.
no_grad
()
def
sample_dpmpp_2m_sde_gpu
(
model
,
x
,
sigmas
,
extra_args
=
None
,
callback
=
None
,
disable
=
None
,
eta
=
1.
,
s_noise
=
1.
,
noise_sampler
=
None
,
solver_type
=
'midpoint'
):
sigma_min
,
sigma_max
=
sigmas
[
sigmas
>
0
].
min
(),
sigmas
.
max
()
noise_sampler
=
BrownianTreeNoiseSampler
(
x
,
sigma_min
,
sigma_max
,
seed
=
extra_args
.
get
(
"seed"
,
None
),
cpu
=
False
)
if
noise_sampler
is
None
else
noise_sampler
return
sample_dpmpp_2m_sde
(
model
,
x
,
sigmas
,
extra_args
=
extra_args
,
callback
=
callback
,
disable
=
disable
,
eta
=
eta
,
s_noise
=
s_noise
,
noise_sampler
=
noise_sampler
,
solver_type
=
solver_type
)
@
torch
.
no_grad
()
def
sample_dpmpp_sde_gpu
(
model
,
x
,
sigmas
,
extra_args
=
None
,
callback
=
None
,
disable
=
None
,
eta
=
1.
,
s_noise
=
1.
,
noise_sampler
=
None
,
r
=
1
/
2
):
sigma_min
,
sigma_max
=
sigmas
[
sigmas
>
0
].
min
(),
sigmas
.
max
()
noise_sampler
=
BrownianTreeNoiseSampler
(
x
,
sigma_min
,
sigma_max
,
seed
=
extra_args
.
get
(
"seed"
,
None
),
cpu
=
False
)
if
noise_sampler
is
None
else
noise_sampler
return
sample_dpmpp_sde
(
model
,
x
,
sigmas
,
extra_args
=
extra_args
,
callback
=
callback
,
disable
=
disable
,
eta
=
eta
,
s_noise
=
s_noise
,
noise_sampler
=
noise_sampler
,
r
=
r
)
comfy/samplers.py
View file @
e57cba4c
...
@@ -483,8 +483,8 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
...
@@ -483,8 +483,8 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
class
KSampler
:
class
KSampler
:
SCHEDULERS
=
[
"normal"
,
"karras"
,
"exponential"
,
"simple"
,
"ddim_uniform"
]
SCHEDULERS
=
[
"normal"
,
"karras"
,
"exponential"
,
"simple"
,
"ddim_uniform"
]
SAMPLERS
=
[
"euler"
,
"euler_ancestral"
,
"heun"
,
"dpm_2"
,
"dpm_2_ancestral"
,
SAMPLERS
=
[
"euler"
,
"euler_ancestral"
,
"heun"
,
"dpm_2"
,
"dpm_2_ancestral"
,
"lms"
,
"dpm_fast"
,
"dpm_adaptive"
,
"dpmpp_2s_ancestral"
,
"dpmpp_sde"
,
"lms"
,
"dpm_fast"
,
"dpm_adaptive"
,
"dpmpp_2s_ancestral"
,
"dpmpp_sde"
,
"dpmpp_sde_gpu"
,
"dpmpp_2m"
,
"dpmpp_2m_sde"
,
"ddim"
,
"uni_pc"
,
"uni_pc_bh2"
]
"dpmpp_2m"
,
"dpmpp_2m_sde"
,
"dpmpp_2m_sde_gpu"
,
"ddim"
,
"uni_pc"
,
"uni_pc_bh2"
]
def
__init__
(
self
,
model
,
steps
,
device
,
sampler
=
None
,
scheduler
=
None
,
denoise
=
None
,
model_options
=
{}):
def
__init__
(
self
,
model
,
steps
,
device
,
sampler
=
None
,
scheduler
=
None
,
denoise
=
None
,
model_options
=
{}):
self
.
model
=
model
self
.
model
=
model
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment