Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
adcbe674
Unverified
Commit
adcbe674
authored
Feb 01, 2024
by
YiYi Xu
Committed by
GitHub
Feb 01, 2024
Browse files
[refactor]Scheduler.set_begin_index (#6728)
parent
ec9840a5
Changes
28
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
313 additions
and
162 deletions
+313
-162
src/diffusers/schedulers/scheduling_heun_discrete.py
src/diffusers/schedulers/scheduling_heun_discrete.py
+33
-31
src/diffusers/schedulers/scheduling_ipndm.py
src/diffusers/schedulers/scheduling_ipndm.py
+36
-10
src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
...users/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
+42
-40
src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
+42
-40
src/diffusers/schedulers/scheduling_lcm.py
src/diffusers/schedulers/scheduling_lcm.py
+36
-10
src/diffusers/schedulers/scheduling_lms_discrete.py
src/diffusers/schedulers/scheduling_lms_discrete.py
+41
-11
src/diffusers/schedulers/scheduling_sasolver.py
src/diffusers/schedulers/scheduling_sasolver.py
+39
-5
src/diffusers/schedulers/scheduling_unipc_multistep.py
src/diffusers/schedulers/scheduling_unipc_multistep.py
+44
-15
No files found.
src/diffusers/schedulers/scheduling_heun_discrete.py
View file @
adcbe674
...
...
@@ -13,7 +13,6 @@
# limitations under the License.
import
math
from
collections
import
defaultdict
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
...
...
@@ -148,8 +147,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
use_karras_sigmas
=
use_karras_sigmas
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
schedule_timesteps
is
None
:
schedule_timesteps
=
self
.
timesteps
...
...
@@ -160,11 +161,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
self
.
_index_counter
)
==
0
:
pos
=
1
if
len
(
indices
)
>
1
else
0
else
:
timestep_int
=
timestep
.
cpu
().
item
()
if
torch
.
is_tensor
(
timestep
)
else
timestep
pos
=
self
.
_index_counter
[
timestep_int
]
pos
=
1
if
len
(
indices
)
>
1
else
0
return
indices
[
pos
].
item
()
...
...
@@ -183,6 +180,24 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
scale_model_input
(
self
,
sample
:
torch
.
FloatTensor
,
...
...
@@ -270,13 +285,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
dt
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# (YiYi Notes: keep this for now since we are keeping add_noise function which use index_for_timestep)
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# we need an index counter
self
.
_index_counter
=
defaultdict
(
int
)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def
_sigma_to_t
(
self
,
sigma
,
log_sigmas
):
# get log sigma
...
...
@@ -333,21 +344,12 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
index_candidates
=
(
self
.
timesteps
==
timestep
).
nonzero
()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
index_candidates
)
>
1
:
step_index
=
index_candidates
[
1
]
if
self
.
begin_index
is
None
:
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
else
:
step_index
=
index_candidates
[
0
]
self
.
_step_index
=
step_index
.
item
()
self
.
_step_index
=
self
.
_begin_index
def
step
(
self
,
...
...
@@ -378,11 +380,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
if
self
.
step_index
is
None
:
self
.
_init_step_index
(
timestep
)
# (YiYi notes: keep this for now since we are keeping the add_noise method)
# advance index counter by 1
timestep_int
=
timestep
.
cpu
().
item
()
if
torch
.
is_tensor
(
timestep
)
else
timestep
self
.
_index_counter
[
timestep_int
]
+=
1
if
self
.
state_in_first_order
:
sigma
=
self
.
sigmas
[
self
.
step_index
]
sigma_next
=
self
.
sigmas
[
self
.
step_index
+
1
]
...
...
@@ -453,6 +450,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
return
SchedulerOutput
(
prev_sample
=
prev_sample
)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def
add_noise
(
self
,
original_samples
:
torch
.
FloatTensor
,
...
...
@@ -469,7 +467,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if
self
.
begin_index
is
None
:
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
else
:
step_indices
=
[
self
.
begin_index
]
*
timesteps
.
shape
[
0
]
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
...
...
src/diffusers/schedulers/scheduling_ipndm.py
View file @
adcbe674
...
...
@@ -56,6 +56,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
# running values
self
.
ets
=
[]
self
.
_step_index
=
None
self
.
_begin_index
=
None
@
property
def
step_index
(
self
):
...
...
@@ -64,6 +65,24 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
"""
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
set_timesteps
(
self
,
num_inference_steps
:
int
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
...
...
@@ -90,24 +109,31 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
self
.
ets
=
[]
self
.
_step_index
=
None
self
.
_begin_index
=
None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.
_
in
it_step_index
def
_
in
it_step_index
(
self
,
timestep
):
if
isinstance
(
timestep
,
torch
.
Tensor
)
:
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.in
dex_for_timestep
def
in
dex_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
schedule_timesteps
is
None
:
schedule_
timestep
s
=
self
.
timesteps
ind
ex_candidates
=
(
self
.
timesteps
==
timestep
).
nonzero
()
ind
ices
=
(
schedule_
timesteps
==
timestep
).
nonzero
()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
index_candidates
)
>
1
:
step_index
=
index_candidates
[
1
]
else
:
step_index
=
index_candidates
[
0
]
pos
=
1
if
len
(
indices
)
>
1
else
0
self
.
_step_index
=
step_index
.
item
()
return
indices
[
pos
].
item
()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
if
self
.
begin_index
is
None
:
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
else
:
self
.
_step_index
=
self
.
_begin_index
def
step
(
self
,
...
...
src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
View file @
adcbe674
...
...
@@ -13,7 +13,6 @@
# limitations under the License.
import
math
from
collections
import
defaultdict
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
...
...
@@ -140,27 +139,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
# set all values
self
.
set_timesteps
(
num_train_timesteps
,
None
,
num_train_timesteps
)
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
schedule_timesteps
is
None
:
schedule_timesteps
=
self
.
timesteps
indices
=
(
schedule_timesteps
==
timestep
).
nonzero
()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
self
.
_index_counter
)
==
0
:
pos
=
1
if
len
(
indices
)
>
1
else
0
else
:
timestep_int
=
timestep
.
cpu
().
item
()
if
torch
.
is_tensor
(
timestep
)
else
timestep
pos
=
self
.
_index_counter
[
timestep_int
]
return
indices
[
pos
].
item
()
@
property
def
init_noise_sigma
(
self
):
# standard deviation of the initial noise distribution
...
...
@@ -176,6 +157,24 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
scale_model_input
(
self
,
sample
:
torch
.
FloatTensor
,
...
...
@@ -295,11 +294,8 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
sample
=
None
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# we need an index counter
self
.
_index_counter
=
defaultdict
(
int
)
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
...
...
@@ -356,23 +352,29 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
def
state_in_first_order
(
self
):
return
self
.
sample
is
None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.
_
in
it_step_index
def
_
in
it_step_index
(
self
,
timestep
):
if
isinstance
(
timestep
,
torch
.
Tensor
)
:
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.in
dex_for_timestep
def
in
dex_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
schedule_timesteps
is
None
:
schedule_
timestep
s
=
self
.
timesteps
ind
ex_candidates
=
(
self
.
timesteps
==
timestep
).
nonzero
()
ind
ices
=
(
schedule_
timesteps
==
timestep
).
nonzero
()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
index_candidates
)
>
1
:
step_index
=
index_candidates
[
1
]
else
:
step_index
=
index_candidates
[
0
]
pos
=
1
if
len
(
indices
)
>
1
else
0
self
.
_step_index
=
step_index
.
item
()
return
indices
[
pos
].
item
()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
if
self
.
begin_index
is
None
:
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
else
:
self
.
_step_index
=
self
.
_begin_index
def
step
(
self
,
...
...
@@ -406,10 +408,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
if
self
.
step_index
is
None
:
self
.
_init_step_index
(
timestep
)
# advance index counter by 1
timestep_int
=
timestep
.
cpu
().
item
()
if
torch
.
is_tensor
(
timestep
)
else
timestep
self
.
_index_counter
[
timestep_int
]
+=
1
if
self
.
state_in_first_order
:
sigma
=
self
.
sigmas
[
self
.
step_index
]
sigma_interpol
=
self
.
sigmas_interpol
[
self
.
step_index
]
...
...
@@ -478,7 +476,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
return
SchedulerOutput
(
prev_sample
=
prev_sample
)
# Copied from diffusers.schedulers.scheduling_
h
eu
n
_discrete.
Heun
DiscreteScheduler.add_noise
# Copied from diffusers.schedulers.scheduling_eu
ler
_discrete.
Euler
DiscreteScheduler.add_noise
def
add_noise
(
self
,
original_samples
:
torch
.
FloatTensor
,
...
...
@@ -495,7 +493,11 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if
self
.
begin_index
is
None
:
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
else
:
step_indices
=
[
self
.
begin_index
]
*
timesteps
.
shape
[
0
]
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
...
...
src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
View file @
adcbe674
...
...
@@ -13,7 +13,6 @@
# limitations under the License.
import
math
from
collections
import
defaultdict
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
...
...
@@ -140,27 +139,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
set_timesteps
(
num_train_timesteps
,
None
,
num_train_timesteps
)
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
schedule_timesteps
is
None
:
schedule_timesteps
=
self
.
timesteps
indices
=
(
schedule_timesteps
==
timestep
).
nonzero
()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
self
.
_index_counter
)
==
0
:
pos
=
1
if
len
(
indices
)
>
1
else
0
else
:
timestep_int
=
timestep
.
cpu
().
item
()
if
torch
.
is_tensor
(
timestep
)
else
timestep
pos
=
self
.
_index_counter
[
timestep_int
]
return
indices
[
pos
].
item
()
@
property
def
init_noise_sigma
(
self
):
# standard deviation of the initial noise distribution
...
...
@@ -176,6 +157,24 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
scale_model_input
(
self
,
sample
:
torch
.
FloatTensor
,
...
...
@@ -280,34 +279,37 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
sample
=
None
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# we need an index counter
self
.
_index_counter
=
defaultdict
(
int
)
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
@
property
def
state_in_first_order
(
self
):
return
self
.
sample
is
None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.
_
in
it_step_index
def
_
in
it_step_index
(
self
,
timestep
):
if
isinstance
(
timestep
,
torch
.
Tensor
)
:
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.in
dex_for_timestep
def
in
dex_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
schedule_timesteps
is
None
:
schedule_
timestep
s
=
self
.
timesteps
ind
ex_candidates
=
(
self
.
timesteps
==
timestep
).
nonzero
()
ind
ices
=
(
schedule_
timesteps
==
timestep
).
nonzero
()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
index_candidates
)
>
1
:
step_index
=
index_candidates
[
1
]
else
:
step_index
=
index_candidates
[
0
]
pos
=
1
if
len
(
indices
)
>
1
else
0
self
.
_step_index
=
step_index
.
item
()
return
indices
[
pos
].
item
()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
if
self
.
begin_index
is
None
:
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
else
:
self
.
_step_index
=
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def
_sigma_to_t
(
self
,
sigma
,
log_sigmas
):
...
...
@@ -388,10 +390,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
if
self
.
step_index
is
None
:
self
.
_init_step_index
(
timestep
)
# advance index counter by 1
timestep_int
=
timestep
.
cpu
().
item
()
if
torch
.
is_tensor
(
timestep
)
else
timestep
self
.
_index_counter
[
timestep_int
]
+=
1
if
self
.
state_in_first_order
:
sigma
=
self
.
sigmas
[
self
.
step_index
]
sigma_interpol
=
self
.
sigmas_interpol
[
self
.
step_index
+
1
]
...
...
@@ -453,7 +451,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
return
SchedulerOutput
(
prev_sample
=
prev_sample
)
# Copied from diffusers.schedulers.scheduling_
h
eu
n
_discrete.
Heun
DiscreteScheduler.add_noise
# Copied from diffusers.schedulers.scheduling_eu
ler
_discrete.
Euler
DiscreteScheduler.add_noise
def
add_noise
(
self
,
original_samples
:
torch
.
FloatTensor
,
...
...
@@ -470,7 +468,11 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if
self
.
begin_index
is
None
:
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
else
:
step_indices
=
[
self
.
begin_index
]
*
timesteps
.
shape
[
0
]
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
...
...
src/diffusers/schedulers/scheduling_lcm.py
View file @
adcbe674
...
...
@@ -250,29 +250,54 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
self
.
custom_timesteps
=
False
self
.
_step_index
=
None
self
.
_begin_index
=
None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.
_
in
it_step_index
def
_
in
it_step_index
(
self
,
timestep
):
if
isinstance
(
timestep
,
torch
.
Tensor
)
:
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.in
dex_for_timestep
def
in
dex_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
schedule_timesteps
is
None
:
schedule_
timestep
s
=
self
.
timesteps
ind
ex_candidates
=
(
self
.
timesteps
==
timestep
).
nonzero
()
ind
ices
=
(
schedule_
timesteps
==
timestep
).
nonzero
()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
index_candidates
)
>
1
:
step_index
=
index_candidates
[
1
]
else
:
step_index
=
index_candidates
[
0
]
pos
=
1
if
len
(
indices
)
>
1
else
0
return
indices
[
pos
].
item
()
self
.
_step_index
=
step_index
.
item
()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
if
self
.
begin_index
is
None
:
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
else
:
self
.
_step_index
=
self
.
_begin_index
@
property
def
step_index
(
self
):
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
scale_model_input
(
self
,
sample
:
torch
.
FloatTensor
,
timestep
:
Optional
[
int
]
=
None
)
->
torch
.
FloatTensor
:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
...
...
@@ -462,6 +487,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
device
=
device
,
dtype
=
torch
.
long
)
self
.
_step_index
=
None
self
.
_begin_index
=
None
def
get_scalings_for_boundary_condition_discrete
(
self
,
timestep
):
self
.
sigma_data
=
0.5
# Default: 0.5
...
...
src/diffusers/schedulers/scheduling_lms_discrete.py
View file @
adcbe674
...
...
@@ -168,6 +168,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
is_scale_input_called
=
False
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
@
property
...
...
@@ -185,6 +186,24 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
scale_model_input
(
self
,
sample
:
torch
.
FloatTensor
,
timestep
:
Union
[
float
,
torch
.
FloatTensor
]
)
->
torch
.
FloatTensor
:
...
...
@@ -280,27 +299,34 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
sigmas
=
torch
.
from_numpy
(
sigmas
).
to
(
device
=
device
)
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
device
=
device
)
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
derivatives
=
[]
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.
_
in
it_step_index
def
_
in
it_step_index
(
self
,
timestep
):
if
isinstance
(
timestep
,
torch
.
Tensor
)
:
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.in
dex_for_timestep
def
in
dex_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
schedule_timesteps
is
None
:
schedule_
timestep
s
=
self
.
timesteps
ind
ex_candidates
=
(
self
.
timesteps
==
timestep
).
nonzero
()
ind
ices
=
(
schedule_
timesteps
==
timestep
).
nonzero
()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
index_candidates
)
>
1
:
step_index
=
index_candidates
[
1
]
else
:
step_index
=
index_candidates
[
0
]
pos
=
1
if
len
(
indices
)
>
1
else
0
return
indices
[
pos
].
item
()
self
.
_step_index
=
step_index
.
item
()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
if
self
.
begin_index
is
None
:
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
else
:
self
.
_step_index
=
self
.
_begin_index
# copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t
def
_sigma_to_t
(
self
,
sigma
,
log_sigmas
):
...
...
@@ -434,7 +460,11 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
step_indices
=
[(
schedule_timesteps
==
t
).
nonzero
().
item
()
for
t
in
timesteps
]
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if
self
.
begin_index
is
None
:
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
else
:
step_indices
=
[
self
.
begin_index
]
*
timesteps
.
shape
[
0
]
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
...
...
src/diffusers/schedulers/scheduling_sasolver.py
View file @
adcbe674
...
...
@@ -212,6 +212,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
self
.
lower_order_nums
=
0
self
.
last_sample
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
@
property
...
...
@@ -221,6 +222,24 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
"""
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
set_timesteps
(
self
,
num_inference_steps
:
int
=
None
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
...
...
@@ -283,6 +302,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
...
...
@@ -925,11 +945,12 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
x_t
=
x_t
.
to
(
x
.
dtype
)
return
x_t
def
_init_step_index
(
self
,
timestep
):
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
schedule_timesteps
is
None
:
schedule_timesteps
=
self
.
timesteps
index_candidates
=
(
s
elf
.
timesteps
==
timestep
).
nonzero
()
index_candidates
=
(
s
chedule_
timesteps
==
timestep
).
nonzero
()
if
len
(
index_candidates
)
==
0
:
step_index
=
len
(
self
.
timesteps
)
-
1
...
...
@@ -942,7 +963,20 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
else
:
step_index
=
index_candidates
[
0
].
item
()
self
.
_step_index
=
step_index
return
step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
"""
Initialize the step_index counter for the scheduler.
"""
if
self
.
begin_index
is
None
:
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
else
:
self
.
_step_index
=
self
.
_begin_index
def
step
(
self
,
...
...
src/diffusers/schedulers/scheduling_unipc_multistep.py
View file @
adcbe674
...
...
@@ -198,6 +198,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self
.
solver_p
=
solver_p
self
.
last_sample
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
@
property
...
...
@@ -207,6 +208,24 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
set_timesteps
(
self
,
num_inference_steps
:
int
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
...
...
@@ -269,6 +288,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
...
...
@@ -698,11 +718,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
x_t
=
x_t
.
to
(
x
.
dtype
)
return
x_t
def
_init_step_index
(
self
,
timestep
):
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
schedule_timesteps
is
None
:
schedule_timesteps
=
self
.
timesteps
index_candidates
=
(
s
elf
.
timesteps
==
timestep
).
nonzero
()
index_candidates
=
(
s
chedule_
timesteps
==
timestep
).
nonzero
()
if
len
(
index_candidates
)
==
0
:
step_index
=
len
(
self
.
timesteps
)
-
1
...
...
@@ -715,7 +736,20 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
else
:
step_index
=
index_candidates
[
0
].
item
()
self
.
_step_index
=
step_index
return
step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
"""
Initialize the step_index counter for the scheduler.
"""
if
self
.
begin_index
is
None
:
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
else
:
self
.
_step_index
=
self
.
_begin_index
def
step
(
self
,
...
...
@@ -830,16 +864,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
step_indices
=
[]
for
timestep
in
timesteps
:
index_candidates
=
(
schedule_timesteps
==
timestep
).
nonzero
()
if
len
(
index_candidates
)
==
0
:
step_index
=
len
(
schedule_timesteps
)
-
1
elif
len
(
index_candidates
)
>
1
:
step_index
=
index_candidates
[
1
].
item
()
else
:
step_index
=
index_candidates
[
0
].
item
()
step_indices
.
append
(
step_index
)
# begin_index is None when the scheduler is used for training
if
self
.
begin_index
is
None
:
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
else
:
step_indices
=
[
self
.
begin_index
]
*
timesteps
.
shape
[
0
]
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment