Commit adcbe674 (unverified), authored Feb 01, 2024 by YiYi Xu, committed via GitHub on Feb 01, 2024

[refactor]Scheduler.set_begin_index (#6728)

parent ec9840a5
Showing 8 changed files with 313 additions and 162 deletions (+313 −162)
src/diffusers/schedulers/scheduling_heun_discrete.py               +33 −31
src/diffusers/schedulers/scheduling_ipndm.py                       +36 −10
src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py  +42 −40
src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py            +42 −40
src/diffusers/schedulers/scheduling_lcm.py                         +36 −10
src/diffusers/schedulers/scheduling_lms_discrete.py                +41 −11
src/diffusers/schedulers/scheduling_sasolver.py                    +39 −5
src/diffusers/schedulers/scheduling_unipc_multistep.py             +44 −15
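The thread through all eight files: each scheduler gains a `begin_index` that the pipeline can set once, so `add_noise` and the first `step` no longer have to search `self.timesteps` per call (a search that is ambiguous when timesteps repeat, as in the Heun/KDPM2 schedulers). A minimal usage sketch of the new API, assuming a diffusers build that includes this commit; the step count and strength are hypothetical values, following the img2img pipeline pattern:

from diffusers import HeunDiscreteScheduler

num_inference_steps = 20
scheduler = HeunDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=num_inference_steps)

# img2img-style truncation of the schedule (strength value is made up)
strength = 0.6
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = scheduler.timesteps[t_start * scheduler.order :]

# the new API: declare once where the truncated schedule begins
scheduler.set_begin_index(t_start * scheduler.order)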
src/diffusers/schedulers/scheduling_heun_discrete.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import math
-from collections import defaultdict
 from typing import List, Optional, Tuple, Union

 import numpy as np
@@ -148,8 +147,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.use_karras_sigmas = use_karras_sigmas

         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
     def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
@@ -160,11 +161,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         # is always the second index (or the last index if there is only 1)
         # This way we can ensure we don't accidentally skip a sigma in
         # case we start in the middle of the denoising schedule (e.g. for image-to-image)
-        if len(self._index_counter) == 0:
-            pos = 1 if len(indices) > 1 else 0
-        else:
-            timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
-            pos = self._index_counter[timestep_int]
+        pos = 1 if len(indices) > 1 else 0

         return indices[pos].item()
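To see why `pos` prefers the second match: Heun runs two solver stages per timestep, so the schedule contains duplicates, and starting mid-schedule must not skip a sigma. A toy illustration with made-up values:

import torch

# hypothetical Heun-style schedule: every timestep after the first appears twice,
# once per internal solver stage
timesteps = torch.tensor([999, 801, 801, 601, 601, 401, 401])

indices = (timesteps == 801).nonzero()   # tensor([[1], [2]])
pos = 1 if len(indices) > 1 else 0
print(indices[pos].item())               # 2 -> the second occurrence, so no sigma is skipped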
@@ -183,6 +180,24 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         """
         return self._step_index

+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
     def scale_model_input(
         self,
         sample: torch.FloatTensor,
@@ -270,13 +285,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.dt = None

         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

-        # (YiYi Notes: keep this for now since we are keeping add_noise function which use index_for_timestep)
-        # for exp beta schedules, such as the one for `pipeline_shap_e.py`
-        # we need an index counter
-        self._index_counter = defaultdict(int)
-
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
     def _sigma_to_t(self, sigma, log_sigmas):
         # get log sigma
@@ -333,21 +344,12 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
     def _init_step_index(self, timestep):
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
-
-        index_candidates = (self.timesteps == timestep).nonzero()
-
-        # The sigma index that is taken for the **very** first `step`
-        # is always the second index (or the last index if there is only 1)
-        # This way we can ensure we don't accidentally skip a sigma in
-        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
-        if len(index_candidates) > 1:
-            step_index = index_candidates[1]
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
         else:
-            step_index = index_candidates[0]
-
-        self._step_index = step_index.item()
+            self._step_index = self._begin_index

     def step(
         self,
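The rewritten `_init_step_index` now has two paths: the old timestep search (kept for training and for pipelines that never call `set_begin_index`) and a direct assignment. A sketch of the control flow, assuming a diffusers build with this commit; the indices are hypothetical, and `_init_step_index` is called directly here only for illustration (it is normally invoked by `step`):

from diffusers import HeunDiscreteScheduler

scheduler = HeunDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=10)

# path 1: begin_index unset -> fall back to searching self.timesteps;
# a duplicated timestep resolves to its second occurrence
scheduler._init_step_index(scheduler.timesteps[6])
assert scheduler.step_index == 6

# path 2: the pipeline declared the starting index -> no search at all
scheduler.set_timesteps(num_inference_steps=10)   # resets both indices
scheduler.set_begin_index(6)
scheduler._init_step_index(scheduler.timesteps[6])
assert scheduler.step_index == 6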
@@ -378,11 +380,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         if self.step_index is None:
             self._init_step_index(timestep)

-        # (YiYi notes: keep this for now since we are keeping the add_noise method)
-        # advance index counter by 1
-        timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
-        self._index_counter[timestep_int] += 1
-
         if self.state_in_first_order:
             sigma = self.sigmas[self.step_index]
             sigma_next = self.sigmas[self.step_index + 1]
@@ -453,6 +450,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return SchedulerOutput(prev_sample=prev_sample)

+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.FloatTensor,
@@ -469,7 +467,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
             schedule_timesteps = self.timesteps.to(original_samples.device)
             timesteps = timesteps.to(original_samples.device)

-        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
+        if self.begin_index is None:
+            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        else:
+            step_indices = [self.begin_index] * timesteps.shape[0]

         sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
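`add_noise` is where the begin index pays off for img2img: once several schedule entries share a timestep value, a per-element search cannot know which occurrence the pipeline meant, while `[self.begin_index] * batch` is exact. A sketch with hypothetical shapes, assuming a diffusers build with this commit:

import torch
from diffusers import HeunDiscreteScheduler

scheduler = HeunDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=20)

begin_index = 8 * scheduler.order          # hypothetical img2img starting point
scheduler.set_begin_index(begin_index)

sample = torch.randn(2, 4, 64, 64)         # hypothetical latents
noise = torch.randn_like(sample)
t_batch = scheduler.timesteps[begin_index].repeat(2)  # one timestep per batch element

# with begin_index set, add_noise uses sigmas[begin_index] for the whole batch
# instead of searching the (duplicated) timestep schedule per element
noisy = scheduler.add_noise(sample, noise, t_batch)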
src/diffusers/schedulers/scheduling_ipndm.py
@@ -56,6 +56,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
         # running values
         self.ets = []
         self._step_index = None
+        self._begin_index = None

     @property
     def step_index(self):
@@ -64,6 +65,24 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
         """
         return self._step_index

+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -90,24 +109,31 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
         self.ets = []

         self._step_index = None
+        self._begin_index = None

-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
-    def _init_step_index(self, timestep):
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps

-        index_candidates = (self.timesteps == timestep).nonzero()
+        indices = (schedule_timesteps == timestep).nonzero()

         # The sigma index that is taken for the **very** first `step`
         # is always the second index (or the last index if there is only 1)
         # This way we can ensure we don't accidentally skip a sigma in
         # case we start in the middle of the denoising schedule (e.g. for image-to-image)
-        if len(index_candidates) > 1:
-            step_index = index_candidates[1]
-        else:
-            step_index = index_candidates[0]
+        pos = 1 if len(indices) > 1 else 0

-        self._step_index = step_index.item()
+        return indices[pos].item()
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+    def _init_step_index(self, timestep):
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index

     def step(
         self,
src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import math
-from collections import defaultdict
 from typing import List, Optional, Tuple, Union

 import numpy as np
@@ -140,27 +139,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         # set all values
         self.set_timesteps(num_train_timesteps, None, num_train_timesteps)

         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

-    # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
-    def index_for_timestep(self, timestep, schedule_timesteps=None):
-        if schedule_timesteps is None:
-            schedule_timesteps = self.timesteps
-
-        indices = (schedule_timesteps == timestep).nonzero()
-
-        # The sigma index that is taken for the **very** first `step`
-        # is always the second index (or the last index if there is only 1)
-        # This way we can ensure we don't accidentally skip a sigma in
-        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
-        if len(self._index_counter) == 0:
-            pos = 1 if len(indices) > 1 else 0
-        else:
-            timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
-            pos = self._index_counter[timestep_int]
-
-        return indices[pos].item()
-
     @property
     def init_noise_sigma(self):
         # standard deviation of the initial noise distribution
@@ -176,6 +157,24 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         """
         return self._step_index

+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
     def scale_model_input(
         self,
         sample: torch.FloatTensor,
@@ -295,11 +294,8 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.sample = None

-        # for exp beta schedules, such as the one for `pipeline_shap_e.py`
-        # we need an index counter
-        self._index_counter = defaultdict(int)
-
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
@@ -356,23 +352,29 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
     def state_in_first_order(self):
         return self.sample is None

-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
-    def _init_step_index(self, timestep):
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps

-        index_candidates = (self.timesteps == timestep).nonzero()
+        indices = (schedule_timesteps == timestep).nonzero()

         # The sigma index that is taken for the **very** first `step`
         # is always the second index (or the last index if there is only 1)
         # This way we can ensure we don't accidentally skip a sigma in
         # case we start in the middle of the denoising schedule (e.g. for image-to-image)
-        if len(index_candidates) > 1:
-            step_index = index_candidates[1]
-        else:
-            step_index = index_candidates[0]
+        pos = 1 if len(indices) > 1 else 0

-        self._step_index = step_index.item()
+        return indices[pos].item()
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+    def _init_step_index(self, timestep):
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index

     def step(
         self,
@@ -406,10 +408,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         if self.step_index is None:
             self._init_step_index(timestep)

-        # advance index counter by 1
-        timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
-        self._index_counter[timestep_int] += 1
-
         if self.state_in_first_order:
             sigma = self.sigmas[self.step_index]
             sigma_interpol = self.sigmas_interpol[self.step_index]
@@ -478,7 +476,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return SchedulerOutput(prev_sample=prev_sample)

-    # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.FloatTensor,
@@ -495,7 +493,11 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
             schedule_timesteps = self.timesteps.to(original_samples.device)
             timesteps = timesteps.to(original_samples.device)

-        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
+        if self.begin_index is None:
+            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        else:
+            step_indices = [self.begin_index] * timesteps.shape[0]

         sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import math
-from collections import defaultdict
 from typing import List, Optional, Tuple, Union

 import numpy as np
@@ -140,27 +139,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.set_timesteps(num_train_timesteps, None, num_train_timesteps)

         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

-    # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
-    def index_for_timestep(self, timestep, schedule_timesteps=None):
-        if schedule_timesteps is None:
-            schedule_timesteps = self.timesteps
-
-        indices = (schedule_timesteps == timestep).nonzero()
-
-        # The sigma index that is taken for the **very** first `step`
-        # is always the second index (or the last index if there is only 1)
-        # This way we can ensure we don't accidentally skip a sigma in
-        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
-        if len(self._index_counter) == 0:
-            pos = 1 if len(indices) > 1 else 0
-        else:
-            timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
-            pos = self._index_counter[timestep_int]
-
-        return indices[pos].item()
-
     @property
     def init_noise_sigma(self):
         # standard deviation of the initial noise distribution
@@ -176,6 +157,24 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         """
         return self._step_index

+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
     def scale_model_input(
         self,
         sample: torch.FloatTensor,
@@ -280,34 +279,37 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.sample = None

-        # for exp beta schedules, such as the one for `pipeline_shap_e.py`
-        # we need an index counter
-        self._index_counter = defaultdict(int)
-
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     @property
     def state_in_first_order(self):
         return self.sample is None

-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
-    def _init_step_index(self, timestep):
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps

-        index_candidates = (self.timesteps == timestep).nonzero()
+        indices = (schedule_timesteps == timestep).nonzero()

         # The sigma index that is taken for the **very** first `step`
         # is always the second index (or the last index if there is only 1)
         # This way we can ensure we don't accidentally skip a sigma in
         # case we start in the middle of the denoising schedule (e.g. for image-to-image)
-        if len(index_candidates) > 1:
-            step_index = index_candidates[1]
-        else:
-            step_index = index_candidates[0]
+        pos = 1 if len(indices) > 1 else 0

-        self._step_index = step_index.item()
+        return indices[pos].item()
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+    def _init_step_index(self, timestep):
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index

     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
     def _sigma_to_t(self, sigma, log_sigmas):
@@ -388,10 +390,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         if self.step_index is None:
             self._init_step_index(timestep)

-        # advance index counter by 1
-        timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
-        self._index_counter[timestep_int] += 1
-
         if self.state_in_first_order:
             sigma = self.sigmas[self.step_index]
             sigma_interpol = self.sigmas_interpol[self.step_index + 1]
@@ -453,7 +451,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         return SchedulerOutput(prev_sample=prev_sample)

-    # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.FloatTensor,
@@ -470,7 +468,11 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
             schedule_timesteps = self.timesteps.to(original_samples.device)
             timesteps = timesteps.to(original_samples.device)

-        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
+        if self.begin_index is None:
+            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        else:
+            step_indices = [self.begin_index] * timesteps.shape[0]

         sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
src/diffusers/schedulers/scheduling_lcm.py
@@ -250,29 +250,54 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         self.custom_timesteps = False

         self._step_index = None
+        self._begin_index = None

-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
-    def _init_step_index(self, timestep):
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps

-        index_candidates = (self.timesteps == timestep).nonzero()
+        indices = (schedule_timesteps == timestep).nonzero()

         # The sigma index that is taken for the **very** first `step`
         # is always the second index (or the last index if there is only 1)
         # This way we can ensure we don't accidentally skip a sigma in
         # case we start in the middle of the denoising schedule (e.g. for image-to-image)
-        if len(index_candidates) > 1:
-            step_index = index_candidates[1]
-        else:
-            step_index = index_candidates[0]
+        pos = 1 if len(indices) > 1 else 0

-        self._step_index = step_index.item()
+        return indices[pos].item()
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+    def _init_step_index(self, timestep):
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index

     @property
     def step_index(self):
         return self._step_index

+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
     def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
@@ -462,6 +487,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)

         self._step_index = None
+        self._begin_index = None

     def get_scalings_for_boundary_condition_discrete(self, timestep):
         self.sigma_data = 0.5  # Default: 0.5
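Note that `set_timesteps` resets `_begin_index` along with `_step_index`, so a pipeline has to call `set_begin_index` after `set_timesteps`, not before. A small sketch, assuming a diffusers build with this commit; the step count and index are hypothetical:

from diffusers import LCMScheduler

scheduler = LCMScheduler()
scheduler.set_timesteps(num_inference_steps=4)
scheduler.set_begin_index(2)
assert scheduler.begin_index == 2

scheduler.set_timesteps(num_inference_steps=4)   # wipes the begin index again
assert scheduler.begin_index is None             # back to the timestep-search fallback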
src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -168,6 +168,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.is_scale_input_called = False

         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     @property
@@ -185,6 +186,24 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         """
         return self._step_index

+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
     def scale_model_input(self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]) -> torch.FloatTensor:
@@ -280,27 +299,34 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.sigmas = torch.from_numpy(sigmas).to(device=device)
         self.timesteps = torch.from_numpy(timesteps).to(device=device)

         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

         self.derivatives = []

-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
-    def _init_step_index(self, timestep):
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps

-        index_candidates = (self.timesteps == timestep).nonzero()
+        indices = (schedule_timesteps == timestep).nonzero()

         # The sigma index that is taken for the **very** first `step`
         # is always the second index (or the last index if there is only 1)
         # This way we can ensure we don't accidentally skip a sigma in
         # case we start in the middle of the denoising schedule (e.g. for image-to-image)
-        if len(index_candidates) > 1:
-            step_index = index_candidates[1]
-        else:
-            step_index = index_candidates[0]
+        pos = 1 if len(indices) > 1 else 0

-        self._step_index = step_index.item()
+        return indices[pos].item()
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+    def _init_step_index(self, timestep):
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index

     # copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t
     def _sigma_to_t(self, sigma, log_sigmas):
@@ -434,7 +460,11 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
             schedule_timesteps = self.timesteps.to(original_samples.device)
             timesteps = timesteps.to(original_samples.device)

-        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
+        if self.begin_index is None:
+            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        else:
+            step_indices = [self.begin_index] * timesteps.shape[0]

         sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
src/diffusers/schedulers/scheduling_sasolver.py
@@ -212,6 +212,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         self.lower_order_nums = 0
         self.last_sample = None
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     @property
@@ -221,6 +222,24 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         """
         return self._step_index

+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
     def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -283,6 +302,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         # add an index counter for schedulers that allow duplicated timesteps
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
@@ -925,11 +945,12 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         x_t = x_t.to(x.dtype)
         return x_t

-    def _init_step_index(self, timestep):
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps

-        index_candidates = (self.timesteps == timestep).nonzero()
+        index_candidates = (schedule_timesteps == timestep).nonzero()

         if len(index_candidates) == 0:
             step_index = len(self.timesteps) - 1
@@ -942,7 +963,20 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         else:
             step_index = index_candidates[0].item()

-        self._step_index = step_index
+        return step_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
+    def _init_step_index(self, timestep):
+        """
+        Initialize the step_index counter for the scheduler.
+        """
+
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index

     def step(
         self,
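The DPMSolver-style lookup that SASolver and UniPC now expose as `index_for_timestep` differs from the Euler-style one in how it handles a timestep that is not in the schedule: it clamps to the last index instead of failing. A toy check with made-up values:

import torch

schedule_timesteps = torch.tensor([801, 601, 401, 201, 1])  # hypothetical schedule

def lookup(timestep):
    index_candidates = (schedule_timesteps == timestep).nonzero()
    if len(index_candidates) == 0:
        return len(schedule_timesteps) - 1   # off-schedule -> clamp to the last step
    elif len(index_candidates) > 1:
        return index_candidates[1].item()    # duplicated -> second occurrence
    return index_candidates[0].item()

print(lookup(401))  # 2: exact match
print(lookup(7))    # 4: not in the schedule, clamped to the final index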
src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -198,6 +198,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         self.solver_p = solver_p
         self.last_sample = None
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     @property
@@ -207,6 +208,24 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         """
         return self._step_index

+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -269,6 +288,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         # add an index counter for schedulers that allow duplicated timesteps
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
@@ -698,11 +718,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         x_t = x_t.to(x.dtype)
         return x_t

-    def _init_step_index(self, timestep):
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps

-        index_candidates = (self.timesteps == timestep).nonzero()
+        index_candidates = (schedule_timesteps == timestep).nonzero()

         if len(index_candidates) == 0:
             step_index = len(self.timesteps) - 1
@@ -715,7 +736,20 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         else:
             step_index = index_candidates[0].item()

-        self._step_index = step_index
+        return step_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
+    def _init_step_index(self, timestep):
+        """
+        Initialize the step_index counter for the scheduler.
+        """
+
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index

     def step(
         self,
@@ -830,16 +864,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
             schedule_timesteps = self.timesteps.to(original_samples.device)
             timesteps = timesteps.to(original_samples.device)

-        step_indices = []
-        for timestep in timesteps:
-            index_candidates = (schedule_timesteps == timestep).nonzero()
-            if len(index_candidates) == 0:
-                step_index = len(schedule_timesteps) - 1
-            elif len(index_candidates) > 1:
-                step_index = index_candidates[1].item()
-            else:
-                step_index = index_candidates[0].item()
-            step_indices.append(step_index)
+        # begin_index is None when the scheduler is used for training
+        if self.begin_index is None:
+            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        else:
+            step_indices = [self.begin_index] * timesteps.shape[0]

         sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):