Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
adcbe674
Unverified
Commit
adcbe674
authored
Feb 01, 2024
by
YiYi Xu
Committed by
GitHub
Feb 01, 2024
Browse files
[refactor]Scheduler.set_begin_index (#6728)
parent
ec9840a5
Changes
28
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
307 additions
and
117 deletions
+307
-117
src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
...users/pipelines/controlnet/pipeline_controlnet_img2img.py
+2
-0
src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
...users/pipelines/controlnet/pipeline_controlnet_inpaint.py
+2
-0
src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
...pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
+2
-0
src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
...ted/stable_diffusion_variants/pipeline_cycle_diffusion.py
+2
-0
src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py
...sion_variants/pipeline_stable_diffusion_inpaint_legacy.py
+2
-0
src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
...consistency_models/pipeline_latent_consistency_img2img.py
+2
-0
src/diffusers/pipelines/pia/pipeline_pia.py
src/diffusers/pipelines/pia/pipeline_pia.py
+2
-0
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
...s/stable_diffusion/pipeline_stable_diffusion_depth2img.py
+2
-0
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
...nes/stable_diffusion/pipeline_stable_diffusion_img2img.py
+2
-0
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
...nes/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+2
-0
src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
..._diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
+2
-0
src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
...o_video_synthesis/pipeline_text_to_video_synth_img2img.py
+2
-0
src/diffusers/schedulers/scheduling_consistency_models.py
src/diffusers/schedulers/scheduling_consistency_models.py
+41
-18
src/diffusers/schedulers/scheduling_deis_multistep.py
src/diffusers/schedulers/scheduling_deis_multistep.py
+44
-15
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+41
-15
src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
...sers/schedulers/scheduling_dpmsolver_multistep_inverse.py
+0
-2
src/diffusers/schedulers/scheduling_dpmsolver_sde.py
src/diffusers/schedulers/scheduling_dpmsolver_sde.py
+33
-31
src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+44
-15
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
...ffusers/schedulers/scheduling_euler_ancestral_discrete.py
+41
-11
src/diffusers/schedulers/scheduling_euler_discrete.py
src/diffusers/schedulers/scheduling_euler_discrete.py
+39
-10
No files found.
src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
View file @
adcbe674
...
@@ -789,6 +789,8 @@ class StableDiffusionControlNetImg2ImgPipeline(
...
@@ -789,6 +789,8 @@ class StableDiffusionControlNetImg2ImgPipeline(
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
*
self
.
scheduler
.
order
)
return
timesteps
,
num_inference_steps
-
t_start
return
timesteps
,
num_inference_steps
-
t_start
...
...
src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
View file @
adcbe674
...
@@ -705,6 +705,8 @@ class StableDiffusionControlNetInpaintPipeline(
...
@@ -705,6 +705,8 @@ class StableDiffusionControlNetInpaintPipeline(
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
*
self
.
scheduler
.
order
)
return
timesteps
,
num_inference_steps
-
t_start
return
timesteps
,
num_inference_steps
-
t_start
...
...
src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
View file @
adcbe674
...
@@ -871,6 +871,8 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
...
@@ -871,6 +871,8 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
*
self
.
scheduler
.
order
)
return
timesteps
,
num_inference_steps
-
t_start
return
timesteps
,
num_inference_steps
-
t_start
...
...
src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
View file @
adcbe674
...
@@ -566,6 +566,8 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
...
@@ -566,6 +566,8 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
*
self
.
scheduler
.
order
)
return
timesteps
,
num_inference_steps
-
t_start
return
timesteps
,
num_inference_steps
-
t_start
...
...
src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py
View file @
adcbe674
...
@@ -536,6 +536,8 @@ class StableDiffusionInpaintPipelineLegacy(
...
@@ -536,6 +536,8 @@ class StableDiffusionInpaintPipelineLegacy(
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
*
self
.
scheduler
.
order
)
return
timesteps
,
num_inference_steps
-
t_start
return
timesteps
,
num_inference_steps
-
t_start
...
...
src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
View file @
adcbe674
...
@@ -634,6 +634,8 @@ class LatentConsistencyModelImg2ImgPipeline(
...
@@ -634,6 +634,8 @@ class LatentConsistencyModelImg2ImgPipeline(
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
*
self
.
scheduler
.
order
)
return
timesteps
,
num_inference_steps
-
t_start
return
timesteps
,
num_inference_steps
-
t_start
...
...
src/diffusers/pipelines/pia/pipeline_pia.py
View file @
adcbe674
...
@@ -906,6 +906,8 @@ class PIAPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin
...
@@ -906,6 +906,8 @@ class PIAPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
*
self
.
scheduler
.
order
)
return
timesteps
,
num_inference_steps
-
t_start
return
timesteps
,
num_inference_steps
-
t_start
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
View file @
adcbe674
...
@@ -467,6 +467,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
...
@@ -467,6 +467,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
*
self
.
scheduler
.
order
)
return
timesteps
,
num_inference_steps
-
t_start
return
timesteps
,
num_inference_steps
-
t_start
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
View file @
adcbe674
...
@@ -659,6 +659,8 @@ class StableDiffusionImg2ImgPipeline(
...
@@ -659,6 +659,8 @@ class StableDiffusionImg2ImgPipeline(
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
*
self
.
scheduler
.
order
)
return
timesteps
,
num_inference_steps
-
t_start
return
timesteps
,
num_inference_steps
-
t_start
...
...
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
View file @
adcbe674
...
@@ -859,6 +859,8 @@ class StableDiffusionInpaintPipeline(
...
@@ -859,6 +859,8 @@ class StableDiffusionInpaintPipeline(
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
*
self
.
scheduler
.
order
)
return
timesteps
,
num_inference_steps
-
t_start
return
timesteps
,
num_inference_steps
-
t_start
...
...
src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
View file @
adcbe674
...
@@ -754,6 +754,8 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
...
@@ -754,6 +754,8 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
*
self
.
scheduler
.
order
)
return
timesteps
,
num_inference_steps
-
t_start
return
timesteps
,
num_inference_steps
-
t_start
...
...
src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
View file @
adcbe674
...
@@ -554,6 +554,8 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
...
@@ -554,6 +554,8 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
t_start
=
max
(
num_inference_steps
-
init_timestep
,
0
)
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
timesteps
=
self
.
scheduler
.
timesteps
[
t_start
*
self
.
scheduler
.
order
:]
if
hasattr
(
self
.
scheduler
,
"set_begin_index"
):
self
.
scheduler
.
set_begin_index
(
t_start
*
self
.
scheduler
.
order
)
return
timesteps
,
num_inference_steps
-
t_start
return
timesteps
,
num_inference_steps
-
t_start
...
...
src/diffusers/schedulers/scheduling_consistency_models.py
View file @
adcbe674
...
@@ -98,15 +98,9 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
...
@@ -98,15 +98,9 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
self
.
custom_timesteps
=
False
self
.
custom_timesteps
=
False
self
.
is_scale_input_called
=
False
self
.
is_scale_input_called
=
False
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
schedule_timesteps
is
None
:
schedule_timesteps
=
self
.
timesteps
indices
=
(
schedule_timesteps
==
timestep
).
nonzero
()
return
indices
.
item
()
@
property
@
property
def
step_index
(
self
):
def
step_index
(
self
):
"""
"""
...
@@ -114,6 +108,24 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
...
@@ -114,6 +108,24 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
"""
"""
return
self
.
_step_index
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
scale_model_input
(
def
scale_model_input
(
self
,
sample
:
torch
.
FloatTensor
,
timestep
:
Union
[
float
,
torch
.
FloatTensor
]
self
,
sample
:
torch
.
FloatTensor
,
timestep
:
Union
[
float
,
torch
.
FloatTensor
]
)
->
torch
.
FloatTensor
:
)
->
torch
.
FloatTensor
:
...
@@ -231,6 +243,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
...
@@ -231,6 +243,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
device
=
device
)
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
device
=
device
)
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Modified _convert_to_karras implementation that takes in ramp as argument
# Modified _convert_to_karras implementation that takes in ramp as argument
...
@@ -280,23 +293,29 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
...
@@ -280,23 +293,29 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
c_out
=
(
sigma
-
sigma_min
)
*
sigma_data
/
(
sigma
**
2
+
sigma_data
**
2
)
**
0.5
c_out
=
(
sigma
-
sigma_min
)
*
sigma_data
/
(
sigma
**
2
+
sigma_data
**
2
)
**
0.5
return
c_skip
,
c_out
return
c_skip
,
c_out
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.
_
in
it_step_index
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.in
dex_for_timestep
def
_
in
it_step_index
(
self
,
timestep
):
def
in
dex_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
isinstance
(
timestep
,
torch
.
Tensor
)
:
if
schedule_timesteps
is
None
:
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
schedule_
timestep
s
=
self
.
timesteps
ind
ex_candidates
=
(
self
.
timesteps
==
timestep
).
nonzero
()
ind
ices
=
(
schedule_
timesteps
==
timestep
).
nonzero
()
# The sigma index that is taken for the **very** first `step`
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
index_candidates
)
>
1
:
pos
=
1
if
len
(
indices
)
>
1
else
0
step_index
=
index_candidates
[
1
]
else
:
step_index
=
index_candidates
[
0
]
self
.
_step_index
=
step_index
.
item
()
return
indices
[
pos
].
item
()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
if
self
.
begin_index
is
None
:
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
else
:
self
.
_step_index
=
self
.
_begin_index
def
step
(
def
step
(
self
,
self
,
...
@@ -412,7 +431,11 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
...
@@ -412,7 +431,11 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
step_indices
=
[(
schedule_timesteps
==
t
).
nonzero
().
item
()
for
t
in
timesteps
]
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if
self
.
begin_index
is
None
:
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
else
:
step_indices
=
[
self
.
begin_index
]
*
timesteps
.
shape
[
0
]
sigma
=
sigmas
[
step_indices
].
flatten
()
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
...
...
src/diffusers/schedulers/scheduling_deis_multistep.py
View file @
adcbe674
...
@@ -187,6 +187,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -187,6 +187,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
self
.
model_outputs
=
[
None
]
*
solver_order
self
.
model_outputs
=
[
None
]
*
solver_order
self
.
lower_order_nums
=
0
self
.
lower_order_nums
=
0
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
@
property
@
property
...
@@ -196,6 +197,24 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -196,6 +197,24 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
"""
return
self
.
_step_index
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
set_timesteps
(
self
,
num_inference_steps
:
int
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
def
set_timesteps
(
self
,
num_inference_steps
:
int
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
"""
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
...
@@ -255,6 +274,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -255,6 +274,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
# add an index counter for schedulers that allow duplicated timesteps
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
...
@@ -620,11 +640,12 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -620,11 +640,12 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
else
:
else
:
raise
NotImplementedError
(
"only support log-rho multistep deis now"
)
raise
NotImplementedError
(
"only support log-rho multistep deis now"
)
def
_init_step_index
(
self
,
timestep
):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
if
isinstance
(
timestep
,
torch
.
Tensor
):
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
if
schedule_timesteps
is
None
:
schedule_timesteps
=
self
.
timesteps
index_candidates
=
(
s
elf
.
timesteps
==
timestep
).
nonzero
()
index_candidates
=
(
s
chedule_
timesteps
==
timestep
).
nonzero
()
if
len
(
index_candidates
)
==
0
:
if
len
(
index_candidates
)
==
0
:
step_index
=
len
(
self
.
timesteps
)
-
1
step_index
=
len
(
self
.
timesteps
)
-
1
...
@@ -637,7 +658,20 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -637,7 +658,20 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
else
:
else
:
step_index
=
index_candidates
[
0
].
item
()
step_index
=
index_candidates
[
0
].
item
()
self
.
_step_index
=
step_index
return
step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
"""
Initialize the step_index counter for the scheduler.
"""
if
self
.
begin_index
is
None
:
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
else
:
self
.
_step_index
=
self
.
_begin_index
def
step
(
def
step
(
self
,
self
,
...
@@ -736,16 +770,11 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -736,16 +770,11 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
step_indices
=
[]
# begin_index is None when the scheduler is used for training
for
timestep
in
timesteps
:
if
self
.
begin_index
is
None
:
index_candidates
=
(
schedule_timesteps
==
timestep
).
nonzero
()
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
if
len
(
index_candidates
)
==
0
:
step_index
=
len
(
schedule_timesteps
)
-
1
elif
len
(
index_candidates
)
>
1
:
step_index
=
index_candidates
[
1
].
item
()
else
:
else
:
step_index
=
index_candidates
[
0
].
item
()
step_indices
=
[
self
.
begin_index
]
*
timesteps
.
shape
[
0
]
step_indices
.
append
(
step_index
)
sigma
=
sigmas
[
step_indices
].
flatten
()
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
View file @
adcbe674
...
@@ -227,6 +227,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -227,6 +227,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self
.
model_outputs
=
[
None
]
*
solver_order
self
.
model_outputs
=
[
None
]
*
solver_order
self
.
lower_order_nums
=
0
self
.
lower_order_nums
=
0
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
@
property
@
property
...
@@ -236,6 +237,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -236,6 +237,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
"""
return
self
.
_step_index
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
set_timesteps
(
self
,
num_inference_steps
:
int
=
None
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
def
set_timesteps
(
self
,
num_inference_steps
:
int
=
None
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
"""
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
...
@@ -311,6 +329,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -311,6 +329,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
# add an index counter for schedulers that allow duplicated timesteps
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
...
@@ -792,11 +811,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -792,11 +811,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
)
)
return
x_t
return
x_t
def
_
in
it_step_index
(
self
,
timestep
):
def
in
dex_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
isinstance
(
timestep
,
torch
.
Tensor
)
:
if
schedule_timesteps
is
None
:
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
schedule_
timestep
s
=
self
.
timesteps
index_candidates
=
(
s
elf
.
timesteps
==
timestep
).
nonzero
()
index_candidates
=
(
s
chedule_
timesteps
==
timestep
).
nonzero
()
if
len
(
index_candidates
)
==
0
:
if
len
(
index_candidates
)
==
0
:
step_index
=
len
(
self
.
timesteps
)
-
1
step_index
=
len
(
self
.
timesteps
)
-
1
...
@@ -809,7 +828,19 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -809,7 +828,19 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
else
:
else
:
step_index
=
index_candidates
[
0
].
item
()
step_index
=
index_candidates
[
0
].
item
()
self
.
_step_index
=
step_index
return
step_index
def
_init_step_index
(
self
,
timestep
):
"""
Initialize the step_index counter for the scheduler.
"""
if
self
.
begin_index
is
None
:
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
else
:
self
.
_step_index
=
self
.
_begin_index
def
step
(
def
step
(
self
,
self
,
...
@@ -920,16 +951,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -920,16 +951,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
step_indices
=
[]
# begin_index is None when the scheduler is used for training
for
timestep
in
timesteps
:
if
self
.
begin_index
is
None
:
index_candidates
=
(
schedule_timesteps
==
timestep
).
nonzero
()
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
if
len
(
index_candidates
)
==
0
:
step_index
=
len
(
schedule_timesteps
)
-
1
elif
len
(
index_candidates
)
>
1
:
step_index
=
index_candidates
[
1
].
item
()
else
:
else
:
step_index
=
index_candidates
[
0
].
item
()
step_indices
=
[
self
.
begin_index
]
*
timesteps
.
shape
[
0
]
step_indices
.
append
(
step_index
)
sigma
=
sigmas
[
step_indices
].
flatten
()
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
View file @
adcbe674
...
@@ -767,7 +767,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
...
@@ -767,7 +767,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
)
)
return
x_t
return
x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
def
_init_step_index
(
self
,
timestep
):
if
isinstance
(
timestep
,
torch
.
Tensor
):
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
...
@@ -879,7 +878,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
...
@@ -879,7 +878,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
"""
"""
return
sample
return
sample
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
def
add_noise
(
def
add_noise
(
self
,
self
,
original_samples
:
torch
.
FloatTensor
,
original_samples
:
torch
.
FloatTensor
,
...
...
src/diffusers/schedulers/scheduling_dpmsolver_sde.py
View file @
adcbe674
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
# limitations under the License.
# limitations under the License.
import
math
import
math
from
collections
import
defaultdict
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -198,9 +197,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -198,9 +197,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self
.
noise_sampler
=
None
self
.
noise_sampler
=
None
self
.
noise_sampler_seed
=
noise_sampler_seed
self
.
noise_sampler_seed
=
noise_sampler_seed
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_
h
eu
n
_discrete.
Heun
DiscreteScheduler.index_for_timestep
# Copied from diffusers.schedulers.scheduling_eu
ler
_discrete.
Euler
DiscreteScheduler.index_for_timestep
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
schedule_timesteps
is
None
:
if
schedule_timesteps
is
None
:
schedule_timesteps
=
self
.
timesteps
schedule_timesteps
=
self
.
timesteps
...
@@ -211,31 +211,18 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -211,31 +211,18 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
# is always the second index (or the last index if there is only 1)
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
self
.
_index_counter
)
==
0
:
pos
=
1
if
len
(
indices
)
>
1
else
0
pos
=
1
if
len
(
indices
)
>
1
else
0
else
:
timestep_int
=
timestep
.
cpu
().
item
()
if
torch
.
is_tensor
(
timestep
)
else
timestep
pos
=
self
.
_index_counter
[
timestep_int
]
return
indices
[
pos
].
item
()
return
indices
[
pos
].
item
()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
def
_init_step_index
(
self
,
timestep
):
if
self
.
begin_index
is
None
:
if
isinstance
(
timestep
,
torch
.
Tensor
):
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
index_candidates
=
(
self
.
timesteps
==
timestep
).
nonzero
()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
index_candidates
)
>
1
:
step_index
=
index_candidates
[
1
]
else
:
else
:
step_index
=
index_candidates
[
0
]
self
.
_step_index
=
self
.
_begin_index
self
.
_step_index
=
step_index
.
item
()
@
property
@
property
def
init_noise_sigma
(
self
):
def
init_noise_sigma
(
self
):
...
@@ -252,6 +239,24 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -252,6 +239,24 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
"""
"""
return
self
.
_step_index
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
scale_model_input
(
def
scale_model_input
(
self
,
self
,
sample
:
torch
.
FloatTensor
,
sample
:
torch
.
FloatTensor
,
...
@@ -348,13 +353,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -348,13 +353,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self
.
mid_point_sigma
=
None
self
.
mid_point_sigma
=
None
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
noise_sampler
=
None
self
.
noise_sampler
=
None
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# we need an index counter
self
.
_index_counter
=
defaultdict
(
int
)
def
_second_order_timesteps
(
self
,
sigmas
,
log_sigmas
):
def
_second_order_timesteps
(
self
,
sigmas
,
log_sigmas
):
def
sigma_fn
(
_t
):
def
sigma_fn
(
_t
):
return
np
.
exp
(
-
_t
)
return
np
.
exp
(
-
_t
)
...
@@ -444,10 +446,6 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -444,10 +446,6 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
if
self
.
step_index
is
None
:
if
self
.
step_index
is
None
:
self
.
_init_step_index
(
timestep
)
self
.
_init_step_index
(
timestep
)
# advance index counter by 1
timestep_int
=
timestep
.
cpu
().
item
()
if
torch
.
is_tensor
(
timestep
)
else
timestep
self
.
_index_counter
[
timestep_int
]
+=
1
# Create a noise sampler if it hasn't been created yet
# Create a noise sampler if it hasn't been created yet
if
self
.
noise_sampler
is
None
:
if
self
.
noise_sampler
is
None
:
min_sigma
,
max_sigma
=
self
.
sigmas
[
self
.
sigmas
>
0
].
min
(),
self
.
sigmas
.
max
()
min_sigma
,
max_sigma
=
self
.
sigmas
[
self
.
sigmas
>
0
].
min
(),
self
.
sigmas
.
max
()
...
@@ -527,7 +525,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -527,7 +525,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
return
SchedulerOutput
(
prev_sample
=
prev_sample
)
return
SchedulerOutput
(
prev_sample
=
prev_sample
)
# Copied from diffusers.schedulers.scheduling_
h
eu
n
_discrete.
Heun
DiscreteScheduler.add_noise
# Copied from diffusers.schedulers.scheduling_eu
ler
_discrete.
Euler
DiscreteScheduler.add_noise
def
add_noise
(
def
add_noise
(
self
,
self
,
original_samples
:
torch
.
FloatTensor
,
original_samples
:
torch
.
FloatTensor
,
...
@@ -544,7 +542,11 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -544,7 +542,11 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if
self
.
begin_index
is
None
:
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
else
:
step_indices
=
[
self
.
begin_index
]
*
timesteps
.
shape
[
0
]
sigma
=
sigmas
[
step_indices
].
flatten
()
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
...
...
src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
View file @
adcbe674
...
@@ -210,6 +210,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -210,6 +210,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self
.
sample
=
None
self
.
sample
=
None
self
.
order_list
=
self
.
get_order_list
(
num_train_timesteps
)
self
.
order_list
=
self
.
get_order_list
(
num_train_timesteps
)
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
def
get_order_list
(
self
,
num_inference_steps
:
int
)
->
List
[
int
]:
def
get_order_list
(
self
,
num_inference_steps
:
int
)
->
List
[
int
]:
...
@@ -253,6 +254,24 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -253,6 +254,24 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
"""
"""
return
self
.
_step_index
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
set_timesteps
(
self
,
num_inference_steps
:
int
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
def
set_timesteps
(
self
,
num_inference_steps
:
int
,
device
:
Union
[
str
,
torch
.
device
]
=
None
):
"""
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
...
@@ -315,6 +334,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -315,6 +334,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
# add an index counter for schedulers that allow duplicated timesteps
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
...
@@ -813,11 +833,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -813,11 +833,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
else
:
else
:
raise
ValueError
(
f
"Order must be 1, 2, 3, got
{
order
}
"
)
raise
ValueError
(
f
"Order must be 1, 2, 3, got
{
order
}
"
)
def
_init_step_index
(
self
,
timestep
):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
if
isinstance
(
timestep
,
torch
.
Tensor
):
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
if
schedule_timesteps
is
None
:
schedule_timesteps
=
self
.
timesteps
index_candidates
=
(
s
elf
.
timesteps
==
timestep
).
nonzero
()
index_candidates
=
(
s
chedule_
timesteps
==
timestep
).
nonzero
()
if
len
(
index_candidates
)
==
0
:
if
len
(
index_candidates
)
==
0
:
step_index
=
len
(
self
.
timesteps
)
-
1
step_index
=
len
(
self
.
timesteps
)
-
1
...
@@ -830,7 +851,20 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -830,7 +851,20 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
else
:
else
:
step_index
=
index_candidates
[
0
].
item
()
step_index
=
index_candidates
[
0
].
item
()
self
.
_step_index
=
step_index
return
step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
"""
Initialize the step_index counter for the scheduler.
"""
if
self
.
begin_index
is
None
:
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
else
:
self
.
_step_index
=
self
.
_begin_index
def
step
(
def
step
(
self
,
self
,
...
@@ -925,16 +959,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -925,16 +959,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
step_indices
=
[]
# begin_index is None when the scheduler is used for training
for
timestep
in
timesteps
:
if
self
.
begin_index
is
None
:
index_candidates
=
(
schedule_timesteps
==
timestep
).
nonzero
()
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
if
len
(
index_candidates
)
==
0
:
step_index
=
len
(
schedule_timesteps
)
-
1
elif
len
(
index_candidates
)
>
1
:
step_index
=
index_candidates
[
1
].
item
()
else
:
else
:
step_index
=
index_candidates
[
0
].
item
()
step_indices
=
[
self
.
begin_index
]
*
timesteps
.
shape
[
0
]
step_indices
.
append
(
step_index
)
sigma
=
sigmas
[
step_indices
].
flatten
()
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
...
...
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
View file @
adcbe674
...
@@ -216,6 +216,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -216,6 +216,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
is_scale_input_called
=
False
self
.
is_scale_input_called
=
False
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
@
property
@
property
...
@@ -233,6 +234,24 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -233,6 +234,24 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
"""
return
self
.
_step_index
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
scale_model_input
(
def
scale_model_input
(
self
,
sample
:
torch
.
FloatTensor
,
timestep
:
Union
[
float
,
torch
.
FloatTensor
]
self
,
sample
:
torch
.
FloatTensor
,
timestep
:
Union
[
float
,
torch
.
FloatTensor
]
)
->
torch
.
FloatTensor
:
)
->
torch
.
FloatTensor
:
...
@@ -300,25 +319,32 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -300,25 +319,32 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
device
=
device
)
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
device
=
device
)
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.
_
in
it_step_index
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.in
dex_for_timestep
def
_
in
it_step_index
(
self
,
timestep
):
def
in
dex_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
isinstance
(
timestep
,
torch
.
Tensor
)
:
if
schedule_timesteps
is
None
:
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
schedule_
timestep
s
=
self
.
timesteps
ind
ex_candidates
=
(
self
.
timesteps
==
timestep
).
nonzero
()
ind
ices
=
(
schedule_
timesteps
==
timestep
).
nonzero
()
# The sigma index that is taken for the **very** first `step`
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
index_candidates
)
>
1
:
pos
=
1
if
len
(
indices
)
>
1
else
0
step_index
=
index_candidates
[
1
]
else
:
step_index
=
index_candidates
[
0
]
self
.
_step_index
=
step_index
.
item
()
return
indices
[
pos
].
item
()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
if
self
.
begin_index
is
None
:
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
else
:
self
.
_step_index
=
self
.
_begin_index
def
step
(
def
step
(
self
,
self
,
...
@@ -440,7 +466,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -440,7 +466,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
step_indices
=
[(
schedule_timesteps
==
t
).
nonzero
().
item
()
for
t
in
timesteps
]
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if
self
.
begin_index
is
None
:
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
else
:
step_indices
=
[
self
.
begin_index
]
*
timesteps
.
shape
[
0
]
sigma
=
sigmas
[
step_indices
].
flatten
()
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
...
...
src/diffusers/schedulers/scheduling_euler_discrete.py
View file @
adcbe674
...
@@ -237,6 +237,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -237,6 +237,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
use_karras_sigmas
=
use_karras_sigmas
self
.
use_karras_sigmas
=
use_karras_sigmas
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
@
property
@
property
...
@@ -255,6 +256,24 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -255,6 +256,24 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
"""
return
self
.
_step_index
return
self
.
_step_index
@
property
def
begin_index
(
self
):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return
self
.
_begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def
set_begin_index
(
self
,
begin_index
:
int
=
0
):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self
.
_begin_index
=
begin_index
def
scale_model_input
(
def
scale_model_input
(
self
,
sample
:
torch
.
FloatTensor
,
timestep
:
Union
[
float
,
torch
.
FloatTensor
]
self
,
sample
:
torch
.
FloatTensor
,
timestep
:
Union
[
float
,
torch
.
FloatTensor
]
)
->
torch
.
FloatTensor
:
)
->
torch
.
FloatTensor
:
...
@@ -342,6 +361,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -342,6 +361,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
sigmas
=
torch
.
cat
([
sigmas
,
torch
.
zeros
(
1
,
device
=
sigmas
.
device
)])
self
.
sigmas
=
torch
.
cat
([
sigmas
,
torch
.
zeros
(
1
,
device
=
sigmas
.
device
)])
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
_begin_index
=
None
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
sigmas
=
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
def
_sigma_to_t
(
self
,
sigma
,
log_sigmas
):
def
_sigma_to_t
(
self
,
sigma
,
log_sigmas
):
...
@@ -393,22 +413,27 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -393,22 +413,27 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas
=
(
max_inv_rho
+
ramp
*
(
min_inv_rho
-
max_inv_rho
))
**
rho
sigmas
=
(
max_inv_rho
+
ramp
*
(
min_inv_rho
-
max_inv_rho
))
**
rho
return
sigmas
return
sigmas
def
_
in
it_step_index
(
self
,
timestep
):
def
in
dex_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
isinstance
(
timestep
,
torch
.
Tensor
)
:
if
schedule_timesteps
is
None
:
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
schedule_
timestep
s
=
self
.
timesteps
ind
ex_candidates
=
(
self
.
timesteps
==
timestep
).
nonzero
()
ind
ices
=
(
schedule_
timesteps
==
timestep
).
nonzero
()
# The sigma index that is taken for the **very** first `step`
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
if
len
(
index_candidates
)
>
1
:
pos
=
1
if
len
(
indices
)
>
1
else
0
step_index
=
index_candidates
[
1
]
else
:
step_index
=
index_candidates
[
0
]
self
.
_step_index
=
step_index
.
item
()
return
indices
[
pos
].
item
()
def
_init_step_index
(
self
,
timestep
):
if
self
.
begin_index
is
None
:
if
isinstance
(
timestep
,
torch
.
Tensor
):
timestep
=
timestep
.
to
(
self
.
timesteps
.
device
)
self
.
_step_index
=
self
.
index_for_timestep
(
timestep
)
else
:
self
.
_step_index
=
self
.
_begin_index
def
step
(
def
step
(
self
,
self
,
...
@@ -538,7 +563,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -538,7 +563,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
schedule_timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
step_indices
=
[(
schedule_timesteps
==
t
).
nonzero
().
item
()
for
t
in
timesteps
]
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if
self
.
begin_index
is
None
:
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
else
:
step_indices
=
[
self
.
begin_index
]
*
timesteps
.
shape
[
0
]
sigma
=
sigmas
[
step_indices
].
flatten
()
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment