Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
fe4837a9
Unverified
Commit
fe4837a9
authored
Sep 14, 2023
by
YiYi Xu
Committed by
GitHub
Sep 14, 2023
Browse files
add step_index and clear noise_sampler at begining of each loop (#5024)
Co-authored-by:
yiyixuxu
<
yixu310@gmail,com
>
parent
342c5c02
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
41 additions
and
7 deletions
+41
-7
src/diffusers/schedulers/scheduling_dpmsolver_sde.py
src/diffusers/schedulers/scheduling_dpmsolver_sde.py
+41
-7
No files found.
src/diffusers/schedulers/scheduling_dpmsolver_sde.py
View file @
fe4837a9
...
@@ -199,6 +199,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -199,6 +199,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self
.
use_karras_sigmas
=
use_karras_sigmas
self
.
use_karras_sigmas
=
use_karras_sigmas
self
.
noise_sampler
=
None
self
.
noise_sampler
=
None
self
.
noise_sampler_seed
=
noise_sampler_seed
self
.
noise_sampler_seed
=
noise_sampler_seed
self
.
_step_index
=
None
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
...
@@ -219,6 +220,24 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -219,6 +220,24 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
return
indices
[
pos
].
item
()
return
indices
[
pos
].
item
()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
index_candidates
=
(
self
.
timesteps
==
timestep
).
nonzero
()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
index_candidates
)
>
1
:
step_index
=
index_candidates
[
1
]
else
:
step_index
=
index_candidates
[
0
]
self
.
_step_index
=
step_index
.
item
()
@
property
@
property
def
init_noise_sigma
(
self
):
def
init_noise_sigma
(
self
):
# standard deviation of the initial noise distribution
# standard deviation of the initial noise distribution
...
@@ -227,6 +246,13 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -227,6 +246,13 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
return
(
self
.
sigmas
.
max
()
**
2
+
1
)
**
0.5
return
(
self
.
sigmas
.
max
()
**
2
+
1
)
**
0.5
@
property
def
step_index
(
self
):
"""
The index counter for current timestep. It will increae 1 after each scheduler step.
"""
return
self
.
_step_index
def
scale_model_input
(
def
scale_model_input
(
self
,
self
,
sample
:
torch
.
FloatTensor
,
sample
:
torch
.
FloatTensor
,
...
@@ -246,9 +272,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -246,9 +272,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`:
`torch.FloatTensor`:
A scaled input sample.
A scaled input sample.
"""
"""
step_index
=
self
.
index_for_timestep
(
timestep
)
if
self
.
step_index
is
None
:
self
.
_init_step_index
(
timestep
)
sigma
=
self
.
sigmas
[
step_index
]
sigma
=
self
.
sigmas
[
self
.
step_index
]
sigma_input
=
sigma
if
self
.
state_in_first_order
else
self
.
mid_point_sigma
sigma_input
=
sigma
if
self
.
state_in_first_order
else
self
.
mid_point_sigma
sample
=
sample
/
((
sigma_input
**
2
+
1
)
**
0.5
)
sample
=
sample
/
((
sigma_input
**
2
+
1
)
**
0.5
)
return
sample
return
sample
...
@@ -321,6 +348,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -321,6 +348,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self
.
sample
=
None
self
.
sample
=
None
self
.
mid_point_sigma
=
None
self
.
mid_point_sigma
=
None
self
.
_step_index
=
None
self
.
noise_sampler
=
None
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# we need an index counter
# we need an index counter
self
.
_index_counter
=
defaultdict
(
int
)
self
.
_index_counter
=
defaultdict
(
int
)
...
@@ -411,7 +441,8 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -411,7 +441,8 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
tuple is returned where the first element is the sample tensor.
"""
"""
step_index
=
self
.
index_for_timestep
(
timestep
)
if
self
.
step_index
is
None
:
self
.
_init_step_index
(
timestep
)
# advance index counter by 1
# advance index counter by 1
timestep_int
=
timestep
.
cpu
().
item
()
if
torch
.
is_tensor
(
timestep
)
else
timestep
timestep_int
=
timestep
.
cpu
().
item
()
if
torch
.
is_tensor
(
timestep
)
else
timestep
...
@@ -430,12 +461,12 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -430,12 +461,12 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
return
_sigma
.
log
().
neg
()
return
_sigma
.
log
().
neg
()
if
self
.
state_in_first_order
:
if
self
.
state_in_first_order
:
sigma
=
self
.
sigmas
[
step_index
]
sigma
=
self
.
sigmas
[
self
.
step_index
]
sigma_next
=
self
.
sigmas
[
step_index
+
1
]
sigma_next
=
self
.
sigmas
[
self
.
step_index
+
1
]
else
:
else
:
# 2nd order
# 2nd order
sigma
=
self
.
sigmas
[
step_index
-
1
]
sigma
=
self
.
sigmas
[
self
.
step_index
-
1
]
sigma_next
=
self
.
sigmas
[
step_index
]
sigma_next
=
self
.
sigmas
[
self
.
step_index
]
# Set the midpoint and step size for the current step
# Set the midpoint and step size for the current step
midpoint_ratio
=
0.5
midpoint_ratio
=
0.5
...
@@ -488,6 +519,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -488,6 +519,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self
.
sample
=
None
self
.
sample
=
None
self
.
mid_point_sigma
=
None
self
.
mid_point_sigma
=
None
# upon completion increase step index by one
self
.
_step_index
+=
1
if
not
return_dict
:
if
not
return_dict
:
return
(
prev_sample
,)
return
(
prev_sample
,)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment