Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
adcbe674
Unverified
Commit
adcbe674
authored
Feb 01, 2024
by
YiYi Xu
Committed by
GitHub
Feb 01, 2024
Browse files
[refactor]Scheduler.set_begin_index (#6728)
parent
ec9840a5
Changes
28
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
313 additions
and
162 deletions
+313
-162
src/diffusers/schedulers/scheduling_heun_discrete.py
src/diffusers/schedulers/scheduling_heun_discrete.py
+33
-31
src/diffusers/schedulers/scheduling_ipndm.py
src/diffusers/schedulers/scheduling_ipndm.py
+36
-10
src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
...users/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
+42
-40
src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
+42
-40
src/diffusers/schedulers/scheduling_lcm.py
src/diffusers/schedulers/scheduling_lcm.py
+36
-10
src/diffusers/schedulers/scheduling_lms_discrete.py
src/diffusers/schedulers/scheduling_lms_discrete.py
+41
-11
src/diffusers/schedulers/scheduling_sasolver.py
src/diffusers/schedulers/scheduling_sasolver.py
+39
-5
src/diffusers/schedulers/scheduling_unipc_multistep.py
src/diffusers/schedulers/scheduling_unipc_multistep.py
+44
-15
No files found.
src/diffusers/schedulers/scheduling_heun_discrete.py
View file @
adcbe674
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
# limitations under the License.
# limitations under the License.
import
math
import
math
from
collections
import
defaultdict
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -148,8 +147,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -148,8 +147,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
use_karras_sigmas
=
use_karras_sigmas
self
.
use_karras_sigmas
=
use_karras_sigmas
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
schedule_timesteps
is
None
:
if
schedule_timesteps
is
None
:
schedule_timesteps
=
self
.
timesteps
schedule_timesteps
=
self
.
timesteps
...
@@ -160,11 +161,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -160,11 +161,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
# is always the second index (or the last index if there is only 1)
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
self
.
_index_counter
)
==
0
:
pos
=
1
if
len
(
indices
)
>
1
else
0
pos
=
1
if
len
(
indices
)
>
1
else
0
else
:
timestep_int
=
timestep
.
cpu
().
item
()
if
torch
.
is_tensor
(
timestep
)
else
timestep
pos
=
self
.
_index_counter
[
timestep_int
]
return
indices
[
pos
].
item
()
return
indices
[
pos
].
item
()
...
@@ -183,6 +180,24 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -183,6 +180,24 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
"""
return
self
.
_step_index
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
scale_model_input
(
def
scale_model_input
(
self
,
self
,
sample
:
torch
.
FloatTensor
,
sample
:
torch
.
FloatTensor
,
...
@@ -270,13 +285,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -270,13 +285,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
dt
=
None
self
.
dt
=
None
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# (YiYi Notes: keep this for now since we are keeping add_noise function which use index_for_timestep)
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# we need an index counter
self
.
_index_counter
=
defaultdict
(
int
)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def
_sigma_to_t
(
self
,
sigma
,
log_sigmas
):
def
_sigma_to_t
(
self
,
sigma
,
log_sigmas
):
# get log sigma
# get log sigma
...
@@ -333,21 +344,12 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -333,21 +344,12 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
def
_init_step_index
(
self
,
timestep
):
if
isinstance
(
timestep
,
torch
.
Tensor
):
if
self
.
begin_index
is
None
:
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
index_candidates
=
(
self
.
timesteps
==
timestep
).
nonzero
()
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
index_candidates
)
>
1
:
step_index
=
index_candidates
[
1
]
else
:
else
:
step_index
=
index_candidates
[
0
]
self
.
_step_index
=
self
.
_begin_index
self
.
_step_index
=
step_index
.
item
()
def
step
(
def
step
(
self
,
self
,
...
@@ -378,11 +380,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -378,11 +380,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
if
self
.
step_index
is
None
:
if
self
.
step_index
is
None
:
self
.
_init_step_index
(
timestep
)
self
.
_init_step_index
(
timestep
)
# (YiYi notes: keep this for now since we are keeping the add_noise method)
# advance index counter by 1
timestep_int
=
timestep
.
cpu
().
item
()
if
torch
.
is_tensor
(
timestep
)
else
timestep
self
.
_index_counter
[
timestep_int
]
+=
1
if
self
.
state_in_first_order
:
if
self
.
state_in_first_order
:
sigma
=
self
.
sigmas
[
self
.
step_index
]
sigma
=
self
.
sigmas
[
self
.
step_index
]
sigma_next
=
self
.
sigmas
[
self
.
step_index
+
1
]
sigma_next
=
self
.
sigmas
[
self
.
step_index
+
1
]
...
@@ -453,6 +450,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -453,6 +450,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
return
SchedulerOutput
(
prev_sample
=
prev_sample
)
return
SchedulerOutput
(
prev_sample
=
prev_sample
)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def
add_noise
(
def
add_noise
(
self
,
self
,
original_samples
:
torch
.
FloatTensor
,
original_samples
:
torch
.
FloatTensor
,
...
@@ -469,7 +467,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -469,7 +467,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if
self
.
begin_index
is
None
:
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
else
:
step_indices
=
[
self
.
begin_index
]
*
timesteps
.
shape
[
0
]
sigma
=
sigmas
[
step_indices
].
flatten
()
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
...
...
src/diffusers/schedulers/scheduling_ipndm.py
View file @
adcbe674
...
@@ -56,6 +56,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -56,6 +56,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
# running values
# running values
self
.
ets
=
[]
self
.
ets
=
[]
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
@
property
@
property
def
step_index
(
self
):
def
step_index
(
self
):
...
@@ -64,6 +65,24 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -64,6 +65,24 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
"""
"""
return
self
.
_step_index
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
set_timesteps
(
self
,
num_inference_steps
:
int
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
def
set_timesteps
(
self
,
num_inference_steps
:
int
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
"""
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
...
@@ -90,24 +109,31 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -90,24 +109,31 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
self
.
ets
=
[]
self
.
ets
=
[]
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.
_
in
it_step_index
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.in
dex_for_timestep
def
_
in
it_step_index
(
self
,
timestep
):
def
in
dex_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
isinstance
(
timestep
,
torch
.
Tensor
)
:
if
schedule_timesteps
is
None
:
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
schedule_
timestep
s
=
self
.
timesteps
ind
ex_candidates
=
(
self
.
timesteps
==
timestep
).
nonzero
()
ind
ices
=
(
schedule_
timesteps
==
timestep
).
nonzero
()
# The sigma index that is taken for the **very** first `step`
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
index_candidates
)
>
1
:
pos
=
1
if
len
(
indices
)
>
1
else
0
step_index
=
index_candidates
[
1
]
else
:
step_index
=
index_candidates
[
0
]
self
.
_step_index
=
step_index
.
item
()
return
indices
[
pos
].
item
()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
if
self
.
begin_index
is
None
:
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
else
:
self
.
_step_index
=
self
.
_begin_index
def
step
(
def
step
(
self
,
self
,
...
...
src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
View file @
adcbe674
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
# limitations under the License.
# limitations under the License.
import
math
import
math
from
collections
import
defaultdict
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -140,27 +139,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -140,27 +139,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
# set all values
# set all values
self
.
set_timesteps
(
num_train_timesteps
,
None
,
num_train_timesteps
)
self
.
set_timesteps
(
num_train_timesteps
,
None
,
num_train_timesteps
)
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
schedule_timesteps
is
None
:
schedule_timesteps
=
self
.
timesteps
indices
=
(
schedule_timesteps
==
timestep
).
nonzero
()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
self
.
_index_counter
)
==
0
:
pos
=
1
if
len
(
indices
)
>
1
else
0
else
:
timestep_int
=
timestep
.
cpu
().
item
()
if
torch
.
is_tensor
(
timestep
)
else
timestep
pos
=
self
.
_index_counter
[
timestep_int
]
return
indices
[
pos
].
item
()
@
property
@
property
def
init_noise_sigma
(
self
):
def
init_noise_sigma
(
self
):
# standard deviation of the initial noise distribution
# standard deviation of the initial noise distribution
...
@@ -176,6 +157,24 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -176,6 +157,24 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
"""
return
self
.
_step_index
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
scale_model_input
(
def
scale_model_input
(
self
,
self
,
sample
:
torch
.
FloatTensor
,
sample
:
torch
.
FloatTensor
,
...
@@ -295,11 +294,8 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -295,11 +294,8 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
sample
=
None
self
.
sample
=
None
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# we need an index counter
self
.
_index_counter
=
defaultdict
(
int
)
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
...
@@ -356,23 +352,29 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -356,23 +352,29 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
def
state_in_first_order
(
self
):
def
state_in_first_order
(
self
):
return
self
.
sample
is
None
return
self
.
sample
is
None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.
_
in
it_step_index
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.in
dex_for_timestep
def
_
in
it_step_index
(
self
,
timestep
):
def
in
dex_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
isinstance
(
timestep
,
torch
.
Tensor
)
:
if
schedule_timesteps
is
None
:
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
schedule_
timestep
s
=
self
.
timesteps
ind
ex_candidates
=
(
self
.
timesteps
==
timestep
).
nonzero
()
ind
ices
=
(
schedule_
timesteps
==
timestep
).
nonzero
()
# The sigma index that is taken for the **very** first `step`
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
index_candidates
)
>
1
:
pos
=
1
if
len
(
indices
)
>
1
else
0
step_index
=
index_candidates
[
1
]
else
:
step_index
=
index_candidates
[
0
]
self
.
_step_index
=
step_index
.
item
()
return
indices
[
pos
].
item
()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
if
self
.
begin_index
is
None
:
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
else
:
self
.
_step_index
=
self
.
_begin_index
def
step
(
def
step
(
self
,
self
,
...
@@ -406,10 +408,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -406,10 +408,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
if
self
.
step_index
is
None
:
if
self
.
step_index
is
None
:
self
.
_init_step_index
(
timestep
)
self
.
_init_step_index
(
timestep
)
# advance index counter by 1
timestep_int
=
timestep
.
cpu
().
item
()
if
torch
.
is_tensor
(
timestep
)
else
timestep
self
.
_index_counter
[
timestep_int
]
+=
1
if
self
.
state_in_first_order
:
if
self
.
state_in_first_order
:
sigma
=
self
.
sigmas
[
self
.
step_index
]
sigma
=
self
.
sigmas
[
self
.
step_index
]
sigma_interpol
=
self
.
sigmas_interpol
[
self
.
step_index
]
sigma_interpol
=
self
.
sigmas_interpol
[
self
.
step_index
]
...
@@ -478,7 +476,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -478,7 +476,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
return
SchedulerOutput
(
prev_sample
=
prev_sample
)
return
SchedulerOutput
(
prev_sample
=
prev_sample
)
# Copied from diffusers.schedulers.scheduling_
h
eu
n
_discrete.
Heun
DiscreteScheduler.add_noise
# Copied from diffusers.schedulers.scheduling_eu
ler
_discrete.
Euler
DiscreteScheduler.add_noise
def
add_noise
(
def
add_noise
(
self
,
self
,
original_samples
:
torch
.
FloatTensor
,
original_samples
:
torch
.
FloatTensor
,
...
@@ -495,7 +493,11 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -495,7 +493,11 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if
self
.
begin_index
is
None
:
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
else
:
step_indices
=
[
self
.
begin_index
]
*
timesteps
.
shape
[
0
]
sigma
=
sigmas
[
step_indices
].
flatten
()
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
...
...
src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
View file @
adcbe674
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
# limitations under the License.
# limitations under the License.
import
math
import
math
from
collections
import
defaultdict
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -140,27 +139,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -140,27 +139,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
set_timesteps
(
num_train_timesteps
,
None
,
num_train_timesteps
)
self
.
set_timesteps
(
num_train_timesteps
,
None
,
num_train_timesteps
)
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
schedule_timesteps
is
None
:
schedule_timesteps
=
self
.
timesteps
indices
=
(
schedule_timesteps
==
timestep
).
nonzero
()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
self
.
_index_counter
)
==
0
:
pos
=
1
if
len
(
indices
)
>
1
else
0
else
:
timestep_int
=
timestep
.
cpu
().
item
()
if
torch
.
is_tensor
(
timestep
)
else
timestep
pos
=
self
.
_index_counter
[
timestep_int
]
return
indices
[
pos
].
item
()
@
property
@
property
def
init_noise_sigma
(
self
):
def
init_noise_sigma
(
self
):
# standard deviation of the initial noise distribution
# standard deviation of the initial noise distribution
...
@@ -176,6 +157,24 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -176,6 +157,24 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
"""
return
self
.
_step_index
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
scale_model_input
(
def
scale_model_input
(
self
,
self
,
sample
:
torch
.
FloatTensor
,
sample
:
torch
.
FloatTensor
,
...
@@ -280,34 +279,37 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -280,34 +279,37 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
sample
=
None
self
.
sample
=
None
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# we need an index counter
self
.
_index_counter
=
defaultdict
(
int
)
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
@
property
@
property
def
state_in_first_order
(
self
):
def
state_in_first_order
(
self
):
return
self
.
sample
is
None
return
self
.
sample
is
None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.
_
in
it_step_index
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.in
dex_for_timestep
def
_
in
it_step_index
(
self
,
timestep
):
def
in
dex_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
isinstance
(
timestep
,
torch
.
Tensor
)
:
if
schedule_timesteps
is
None
:
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
schedule_
timestep
s
=
self
.
timesteps
ind
ex_candidates
=
(
self
.
timesteps
==
timestep
).
nonzero
()
ind
ices
=
(
schedule_
timesteps
==
timestep
).
nonzero
()
# The sigma index that is taken for the **very** first `step`
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
index_candidates
)
>
1
:
pos
=
1
if
len
(
indices
)
>
1
else
0
step_index
=
index_candidates
[
1
]
else
:
step_index
=
index_candidates
[
0
]
self
.
_step_index
=
step_index
.
item
()
return
indices
[
pos
].
item
()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
if
self
.
begin_index
is
None
:
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
else
:
self
.
_step_index
=
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def
_sigma_to_t
(
self
,
sigma
,
log_sigmas
):
def
_sigma_to_t
(
self
,
sigma
,
log_sigmas
):
...
@@ -388,10 +390,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -388,10 +390,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
if
self
.
step_index
is
None
:
if
self
.
step_index
is
None
:
self
.
_init_step_index
(
timestep
)
self
.
_init_step_index
(
timestep
)
# advance index counter by 1
timestep_int
=
timestep
.
cpu
().
item
()
if
torch
.
is_tensor
(
timestep
)
else
timestep
self
.
_index_counter
[
timestep_int
]
+=
1
if
self
.
state_in_first_order
:
if
self
.
state_in_first_order
:
sigma
=
self
.
sigmas
[
self
.
step_index
]
sigma
=
self
.
sigmas
[
self
.
step_index
]
sigma_interpol
=
self
.
sigmas_interpol
[
self
.
step_index
+
1
]
sigma_interpol
=
self
.
sigmas_interpol
[
self
.
step_index
+
1
]
...
@@ -453,7 +451,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -453,7 +451,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
return
SchedulerOutput
(
prev_sample
=
prev_sample
)
return
SchedulerOutput
(
prev_sample
=
prev_sample
)
# Copied from diffusers.schedulers.scheduling_
h
eu
n
_discrete.
Heun
DiscreteScheduler.add_noise
# Copied from diffusers.schedulers.scheduling_eu
ler
_discrete.
Euler
DiscreteScheduler.add_noise
def
add_noise
(
def
add_noise
(
self
,
self
,
original_samples
:
torch
.
FloatTensor
,
original_samples
:
torch
.
FloatTensor
,
...
@@ -470,7 +468,11 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -470,7 +468,11 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if
self
.
begin_index
is
None
:
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
else
:
step_indices
=
[
self
.
begin_index
]
*
timesteps
.
shape
[
0
]
sigma
=
sigmas
[
step_indices
].
flatten
()
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
...
...
src/diffusers/schedulers/scheduling_lcm.py
View file @
adcbe674
...
@@ -250,29 +250,54 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -250,29 +250,54 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
self
.
custom_timesteps
=
False
self
.
custom_timesteps
=
False
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.
_
in
it_step_index
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.in
dex_for_timestep
def
_
in
it_step_index
(
self
,
timestep
):
def
in
dex_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
isinstance
(
timestep
,
torch
.
Tensor
)
:
if
schedule_timesteps
is
None
:
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
schedule_
timestep
s
=
self
.
timesteps
ind
ex_candidates
=
(
self
.
timesteps
==
timestep
).
nonzero
()
ind
ices
=
(
schedule_
timesteps
==
timestep
).
nonzero
()
# The sigma index that is taken for the **very** first `step`
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
index_candidates
)
>
1
:
pos
=
1
if
len
(
indices
)
>
1
else
0
step_index
=
index_candidates
[
1
]
else
:
return
indices
[
pos
].
item
()
step_index
=
index_candidates
[
0
]
self
.
_step_index
=
step_index
.
item
()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
if
self
.
begin_index
is
None
:
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
else
:
self
.
_step_index
=
self
.
_begin_index
@
property
@
property
def
step_index
(
self
):
def
step_index
(
self
):
return
self
.
_step_index
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
scale_model_input
(
self
,
sample
:
torch
.
FloatTensor
,
timestep
:
Optional
[
int
]
=
None
)
->
torch
.
FloatTensor
:
def
scale_model_input
(
self
,
sample
:
torch
.
FloatTensor
,
timestep
:
Optional
[
int
]
=
None
)
->
torch
.
FloatTensor
:
"""
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
...
@@ -462,6 +487,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -462,6 +487,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
device
=
device
,
dtype
=
torch
.
long
)
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
device
=
device
,
dtype
=
torch
.
long
)
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
def
get_scalings_for_boundary_condition_discrete
(
self
,
timestep
):
def
get_scalings_for_boundary_condition_discrete
(
self
,
timestep
):
self
.
sigma_data
=
0.5
# Default: 0.5
self
.
sigma_data
=
0.5
# Default: 0.5
...
...
src/diffusers/schedulers/scheduling_lms_discrete.py
View file @
adcbe674
...
@@ -168,6 +168,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -168,6 +168,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
is_scale_input_called
=
False
self
.
is_scale_input_called
=
False
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
@
property
@
property
...
@@ -185,6 +186,24 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -185,6 +186,24 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
"""
return
self
.
_step_index
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
scale_model_input
(
def
scale_model_input
(
self
,
sample
:
torch
.
FloatTensor
,
timestep
:
Union
[
float
,
torch
.
FloatTensor
]
self
,
sample
:
torch
.
FloatTensor
,
timestep
:
Union
[
float
,
torch
.
FloatTensor
]
)
->
torch
.
FloatTensor
:
)
->
torch
.
FloatTensor
:
...
@@ -280,27 +299,34 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -280,27 +299,34 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
sigmas
=
torch
.
from_numpy
(
sigmas
).
to
(
device
=
device
)
self
.
sigmas
=
torch
.
from_numpy
(
sigmas
).
to
(
device
=
device
)
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
device
=
device
)
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
device
=
device
)
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
derivatives
=
[]
self
.
derivatives
=
[]
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.
_
in
it_step_index
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.in
dex_for_timestep
def
_
in
it_step_index
(
self
,
timestep
):
def
in
dex_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
isinstance
(
timestep
,
torch
.
Tensor
)
:
if
schedule_timesteps
is
None
:
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
schedule_
timestep
s
=
self
.
timesteps
ind
ex_candidates
=
(
self
.
timesteps
==
timestep
).
nonzero
()
ind
ices
=
(
schedule_
timesteps
==
timestep
).
nonzero
()
# The sigma index that is taken for the **very** first `step`
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
index_candidates
)
>
1
:
pos
=
1
if
len
(
indices
)
>
1
else
0
step_index
=
index_candidates
[
1
]
else
:
return
indices
[
pos
].
item
()
step_index
=
index_candidates
[
0
]
self
.
_step_index
=
step_index
.
item
()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
if
self
.
begin_index
is
None
:
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
else
:
self
.
_step_index
=
self
.
_begin_index
# copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t
# copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t
def
_sigma_to_t
(
self
,
sigma
,
log_sigmas
):
def
_sigma_to_t
(
self
,
sigma
,
log_sigmas
):
...
@@ -434,7 +460,11 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -434,7 +460,11 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
step_indices
=
[(
schedule_timesteps
==
t
).
nonzero
().
item
()
for
t
in
timesteps
]
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if
self
.
begin_index
is
None
:
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
else
:
step_indices
=
[
self
.
begin_index
]
*
timesteps
.
shape
[
0
]
sigma
=
sigmas
[
step_indices
].
flatten
()
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
...
...
src/diffusers/schedulers/scheduling_sasolver.py
View file @
adcbe674
...
@@ -212,6 +212,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
...
@@ -212,6 +212,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
self
.
lower_order_nums
=
0
self
.
lower_order_nums
=
0
self
.
last_sample
=
None
self
.
last_sample
=
None
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
@
property
@
property
...
@@ -221,6 +222,24 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
...
@@ -221,6 +222,24 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
"""
"""
return
self
.
_step_index
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
set_timesteps
(
self
,
num_inference_steps
:
int
=
None
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
def
set_timesteps
(
self
,
num_inference_steps
:
int
=
None
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
"""
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
...
@@ -283,6 +302,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
...
@@ -283,6 +302,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
# add an index counter for schedulers that allow duplicated timesteps
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
...
@@ -925,11 +945,12 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
...
@@ -925,11 +945,12 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
x_t
=
x_t
.
to
(
x
.
dtype
)
x_t
=
x_t
.
to
(
x
.
dtype
)
return
x_t
return
x_t
def
_init_step_index
(
self
,
timestep
):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
if
isinstance
(
timestep
,
torch
.
Tensor
):
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
if
schedule_timesteps
is
None
:
schedule_timesteps
=
self
.
timesteps
index_candidates
=
(
s
elf
.
timesteps
==
timestep
).
nonzero
()
index_candidates
=
(
s
chedule_
timesteps
==
timestep
).
nonzero
()
if
len
(
index_candidates
)
==
0
:
if
len
(
index_candidates
)
==
0
:
step_index
=
len
(
self
.
timesteps
)
-
1
step_index
=
len
(
self
.
timesteps
)
-
1
...
@@ -942,7 +963,20 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
...
@@ -942,7 +963,20 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
else
:
else
:
step_index
=
index_candidates
[
0
].
item
()
step_index
=
index_candidates
[
0
].
item
()
self
.
_step_index
=
step_index
return
step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
"""
Initialize the step_index counter for the scheduler.
"""
if
self
.
begin_index
is
None
:
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
else
:
self
.
_step_index
=
self
.
_begin_index
def
step
(
def
step
(
self
,
self
,
...
...
src/diffusers/schedulers/scheduling_unipc_multistep.py
View file @
adcbe674
...
@@ -198,6 +198,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -198,6 +198,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self
.
solver_p
=
solver_p
self
.
solver_p
=
solver_p
self
.
last_sample
=
None
self
.
last_sample
=
None
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
@
property
@
property
...
@@ -207,6 +208,24 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -207,6 +208,24 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
"""
return
self
.
_step_index
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
set_timesteps
(
self
,
num_inference_steps
:
int
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
def
set_timesteps
(
self
,
num_inference_steps
:
int
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
"""
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
...
@@ -269,6 +288,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -269,6 +288,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
# add an index counter for schedulers that allow duplicated timesteps
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
...
@@ -698,11 +718,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -698,11 +718,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
x_t
=
x_t
.
to
(
x
.
dtype
)
x_t
=
x_t
.
to
(
x
.
dtype
)
return
x_t
return
x_t
def
_init_step_index
(
self
,
timestep
):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
if
isinstance
(
timestep
,
torch
.
Tensor
):
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
if
schedule_timesteps
is
None
:
schedule_timesteps
=
self
.
timesteps
index_candidates
=
(
s
elf
.
timesteps
==
timestep
).
nonzero
()
index_candidates
=
(
s
chedule_
timesteps
==
timestep
).
nonzero
()
if
len
(
index_candidates
)
==
0
:
if
len
(
index_candidates
)
==
0
:
step_index
=
len
(
self
.
timesteps
)
-
1
step_index
=
len
(
self
.
timesteps
)
-
1
...
@@ -715,7 +736,20 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -715,7 +736,20 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
else
:
else
:
step_index
=
index_candidates
[
0
].
item
()
step_index
=
index_candidates
[
0
].
item
()
self
.
_step_index
=
step_index
return
step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
"""
Initialize the step_index counter for the scheduler.
"""
if
self
.
begin_index
is
None
:
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
else
:
self
.
_step_index
=
self
.
_begin_index
def
step
(
def
step
(
self
,
self
,
...
@@ -830,16 +864,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -830,16 +864,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
step_indices
=
[]
# begin_index is None when the scheduler is used for training
for
timestep
in
timesteps
:
if
self
.
begin_index
is
None
:
index_candidates
=
(
schedule_timesteps
==
timestep
).
nonzero
()
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
if
len
(
index_candidates
)
==
0
:
else
:
step_index
=
len
(
schedule_timesteps
)
-
1
step_indices
=
[
self
.
begin_index
]
*
timesteps
.
shape
[
0
]
elif
len
(
index_candidates
)
>
1
:
step_index
=
index_candidates
[
1
].
item
()
else
:
step_index
=
index_candidates
[
0
].
item
()
step_indices
.
append
(
step_index
)
sigma
=
sigmas
[
step_indices
].
flatten
()
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment