Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
fe4837a9
Unverified
Commit
fe4837a9
authored
Sep 14, 2023
by
YiYi Xu
Committed by
GitHub
Sep 14, 2023
Browse files
add step_index and clear noise_sampler at begining of each loop (#5024)
Co-authored-by:
yiyixuxu
<
yixu310@gmail,com
>
parent
342c5c02
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
41 additions
and
7 deletions
+41
-7
src/diffusers/schedulers/scheduling_dpmsolver_sde.py
src/diffusers/schedulers/scheduling_dpmsolver_sde.py
+41
-7
No files found.
src/diffusers/schedulers/scheduling_dpmsolver_sde.py
View file @
fe4837a9
...
@@ -199,6 +199,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -199,6 +199,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self
.
use_karras_sigmas
=
use_karras_sigmas
self
.
use_karras_sigmas
=
use_karras_sigmas
self
.
noise_sampler
=
None
self
.
noise_sampler
=
None
self
.
noise_sampler_seed
=
noise_sampler_seed
self
.
noise_sampler_seed
=
noise_sampler_seed
self
.
_step_index
=
None
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
...
@@ -219,6 +220,24 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -219,6 +220,24 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
return
indices
[
pos
].
item
()
return
indices
[
pos
].
item
()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
index_candidates
=
(
self
.
timesteps
==
timestep
).
nonzero
()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
index_candidates
)
>
1
:
step_index
=
index_candidates
[
1
]
else
:
step_index
=
index_candidates
[
0
]
self
.
_step_index
=
step_index
.
item
()
@
property
@
property
def
init_noise_sigma
(
self
):
def
init_noise_sigma
(
self
):
# standard deviation of the initial noise distribution
# standard deviation of the initial noise distribution
...
@@ -227,6 +246,13 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -227,6 +246,13 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
return
(
self
.
sigmas
.
max
()
**
2
+
1
)
**
0.5
return
(
self
.
sigmas
.
max
()
**
2
+
1
)
**
0.5
@
property
def
step_index
(
self
):
"""
The index counter for current timestep. It will increae 1 after each scheduler step.
"""
return
self
.
_step_index
def
scale_model_input
(
def
scale_model_input
(
self
,
self
,
sample
:
torch
.
FloatTensor
,
sample
:
torch
.
FloatTensor
,
...
@@ -246,9 +272,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -246,9 +272,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
`torch.FloatTensor`:
A scaled input sample.
A scaled input sample.
"""
"""
step_index
=
self
.
index_for_timestep
(
timestep
)
if
self
.
step_index
is
None
:
self
.
_init_step_index
(
timestep
)
sigma
=
self
.
sigmas
[
step_index
]
sigma
=
self
.
sigmas
[
self
.
step_index
]
sigma_input
=
sigma
if
self
.
state_in_first_order
else
self
.
mid_point_sigma
sigma_input
=
sigma
if
self
.
state_in_first_order
else
self
.
mid_point_sigma
sample
=
sample
/
((
sigma_input
**
2
+
1
)
**
0.5
)
sample
=
sample
/
((
sigma_input
**
2
+
1
)
**
0.5
)
return
sample
return
sample
...
@@ -321,6 +348,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -321,6 +348,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self
.
sample
=
None
self
.
sample
=
None
self
.
mid_point_sigma
=
None
self
.
mid_point_sigma
=
None
self
.
_step_index
=
None
self
.
noise_sampler
=
None
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# we need an index counter
# we need an index counter
self
.
_index_counter
=
defaultdict
(
int
)
self
.
_index_counter
=
defaultdict
(
int
)
...
@@ -411,7 +441,8 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -411,7 +441,8 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
tuple is returned where the first element is the sample tensor.
"""
"""
step_index
=
self
.
index_for_timestep
(
timestep
)
if
self
.
step_index
is
None
:
self
.
_init_step_index
(
timestep
)
# advance index counter by 1
# advance index counter by 1
timestep_int
=
timestep
.
cpu
().
item
()
if
torch
.
is_tensor
(
timestep
)
else
timestep
timestep_int
=
timestep
.
cpu
().
item
()
if
torch
.
is_tensor
(
timestep
)
else
timestep
...
@@ -430,12 +461,12 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -430,12 +461,12 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
return
_sigma
.
log
().
neg
()
return
_sigma
.
log
().
neg
()
if
self
.
state_in_first_order
:
if
self
.
state_in_first_order
:
sigma
=
self
.
sigmas
[
step_index
]
sigma
=
self
.
sigmas
[
self
.
step_index
]
sigma_next
=
self
.
sigmas
[
step_index
+
1
]
sigma_next
=
self
.
sigmas
[
self
.
step_index
+
1
]
else
:
else
:
# 2nd order
# 2nd order
sigma
=
self
.
sigmas
[
step_index
-
1
]
sigma
=
self
.
sigmas
[
self
.
step_index
-
1
]
sigma_next
=
self
.
sigmas
[
step_index
]
sigma_next
=
self
.
sigmas
[
self
.
step_index
]
# Set the midpoint and step size for the current step
# Set the midpoint and step size for the current step
midpoint_ratio
=
0.5
midpoint_ratio
=
0.5
...
@@ -488,6 +519,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -488,6 +519,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self
.
sample
=
None
self
.
sample
=
None
self
.
mid_point_sigma
=
None
self
.
mid_point_sigma
=
None
# upon completion increase step index by one
self
.
_step_index
+=
1
if
not
return_dict
:
if
not
return_dict
:
return
(
prev_sample
,)
return
(
prev_sample
,)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment