Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
4836cfad
Unverified
Commit
4836cfad
authored
Dec 15, 2023
by
Patrick von Platen
Committed by
GitHub
Dec 15, 2023
Browse files
[Sigmas] Keep sigmas on CPU (#6173)
* correct * Apply suggestions from code review * make style
parent
1ccbfbb6
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
26 additions
and
0 deletions
+26
-0
src/diffusers/schedulers/scheduling_consistency_models.py
src/diffusers/schedulers/scheduling_consistency_models.py
+2
-0
src/diffusers/schedulers/scheduling_deis_multistep.py
src/diffusers/schedulers/scheduling_deis_multistep.py
+2
-0
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+2
-0
src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
...sers/schedulers/scheduling_dpmsolver_multistep_inverse.py
+2
-0
src/diffusers/schedulers/scheduling_dpmsolver_sde.py
src/diffusers/schedulers/scheduling_dpmsolver_sde.py
+2
-0
src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+2
-0
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
...ffusers/schedulers/scheduling_euler_ancestral_discrete.py
+2
-0
src/diffusers/schedulers/scheduling_euler_discrete.py
src/diffusers/schedulers/scheduling_euler_discrete.py
+2
-0
src/diffusers/schedulers/scheduling_heun_discrete.py
src/diffusers/schedulers/scheduling_heun_discrete.py
+2
-0
src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
...users/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
+2
-0
src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
+2
-0
src/diffusers/schedulers/scheduling_lms_discrete.py
src/diffusers/schedulers/scheduling_lms_discrete.py
+2
-0
src/diffusers/schedulers/scheduling_unipc_multistep.py
src/diffusers/schedulers/scheduling_unipc_multistep.py
+2
-0
No files found.
src/diffusers/schedulers/scheduling_consistency_models.py
View file @
4836cfad
...
@@ -98,6 +98,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
...
@@ -98,6 +98,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
self
.
custom_timesteps
=
False
self
.
custom_timesteps
=
False
self
.
is_scale_input_called
=
False
self
.
is_scale_input_called
=
False
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
schedule_timesteps
is
None
:
if
schedule_timesteps
is
None
:
...
@@ -230,6 +231,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
...
@@ -230,6 +231,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
device
=
device
)
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
device
=
device
)
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Modified _convert_to_karras implementation that takes in ramp as argument
# Modified _convert_to_karras implementation that takes in ramp as argument
def
_convert_to_karras
(
self
,
ramp
):
def
_convert_to_karras
(
self
,
ramp
):
...
...
src/diffusers/schedulers/scheduling_deis_multistep.py
View file @
4836cfad
...
@@ -187,6 +187,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -187,6 +187,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
self
.
model_outputs
=
[
None
]
*
solver_order
self
.
model_outputs
=
[
None
]
*
solver_order
self
.
lower_order_nums
=
0
self
.
lower_order_nums
=
0
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
@
property
@
property
def
step_index
(
self
):
def
step_index
(
self
):
...
@@ -254,6 +255,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -254,6 +255,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
# add an index counter for schedulers that allow duplicated timesteps
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
View file @
4836cfad
...
@@ -214,6 +214,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -214,6 +214,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self
.
model_outputs
=
[
None
]
*
solver_order
self
.
model_outputs
=
[
None
]
*
solver_order
self
.
lower_order_nums
=
0
self
.
lower_order_nums
=
0
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
@
property
@
property
def
step_index
(
self
):
def
step_index
(
self
):
...
@@ -290,6 +291,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -290,6 +291,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
# add an index counter for schedulers that allow duplicated timesteps
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
View file @
4836cfad
...
@@ -209,6 +209,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
...
@@ -209,6 +209,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
self
.
model_outputs
=
[
None
]
*
solver_order
self
.
model_outputs
=
[
None
]
*
solver_order
self
.
lower_order_nums
=
0
self
.
lower_order_nums
=
0
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
use_karras_sigmas
=
use_karras_sigmas
self
.
use_karras_sigmas
=
use_karras_sigmas
@
property
@
property
...
@@ -289,6 +290,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
...
@@ -289,6 +290,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
# add an index counter for schedulers that allow duplicated timesteps
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
...
...
src/diffusers/schedulers/scheduling_dpmsolver_sde.py
View file @
4836cfad
...
@@ -198,6 +198,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -198,6 +198,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self
.
noise_sampler
=
None
self
.
noise_sampler
=
None
self
.
noise_sampler_seed
=
noise_sampler_seed
self
.
noise_sampler_seed
=
noise_sampler_seed
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
...
@@ -347,6 +348,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
...
@@ -347,6 +348,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
self
.
mid_point_sigma
=
None
self
.
mid_point_sigma
=
None
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
noise_sampler
=
None
self
.
noise_sampler
=
None
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
...
...
src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
View file @
4836cfad
...
@@ -197,6 +197,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -197,6 +197,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self
.
sample
=
None
self
.
sample
=
None
self
.
order_list
=
self
.
get_order_list
(
num_train_timesteps
)
self
.
order_list
=
self
.
get_order_list
(
num_train_timesteps
)
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
def
get_order_list
(
self
,
num_inference_steps
:
int
)
->
List
[
int
]:
def
get_order_list
(
self
,
num_inference_steps
:
int
)
->
List
[
int
]:
"""
"""
...
@@ -288,6 +289,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -288,6 +289,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
# add an index counter for schedulers that allow duplicated timesteps
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
...
...
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
View file @
4836cfad
...
@@ -166,6 +166,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -166,6 +166,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
is_scale_input_called
=
False
self
.
is_scale_input_called
=
False
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
@
property
@
property
def
init_noise_sigma
(
self
):
def
init_noise_sigma
(
self
):
...
@@ -249,6 +250,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -249,6 +250,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
device
=
device
)
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
device
=
device
)
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def
_init_step_index
(
self
,
timestep
):
def
_init_step_index
(
self
,
timestep
):
...
...
src/diffusers/schedulers/scheduling_euler_discrete.py
View file @
4836cfad
...
@@ -237,6 +237,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -237,6 +237,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
use_karras_sigmas
=
use_karras_sigmas
self
.
use_karras_sigmas
=
use_karras_sigmas
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
@
property
@
property
def
init_noise_sigma
(
self
):
def
init_noise_sigma
(
self
):
...
@@ -341,6 +342,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -341,6 +342,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
sigmas
=
torch
.
cat
([
sigmas
,
torch
.
zeros
(
1
,
device
=
sigmas
.
device
)])
self
.
sigmas
=
torch
.
cat
([
sigmas
,
torch
.
zeros
(
1
,
device
=
sigmas
.
device
)])
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
def
_sigma_to_t
(
self
,
sigma
,
log_sigmas
):
def
_sigma_to_t
(
self
,
sigma
,
log_sigmas
):
# get log sigma
# get log sigma
...
...
src/diffusers/schedulers/scheduling_heun_discrete.py
View file @
4836cfad
...
@@ -148,6 +148,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -148,6 +148,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
use_karras_sigmas
=
use_karras_sigmas
self
.
use_karras_sigmas
=
use_karras_sigmas
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
schedule_timesteps
is
None
:
if
schedule_timesteps
is
None
:
...
@@ -269,6 +270,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -269,6 +270,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
dt
=
None
self
.
dt
=
None
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# (YiYi Notes: keep this for now since we are keeping add_noise function which use index_for_timestep)
# (YiYi Notes: keep this for now since we are keeping add_noise function which use index_for_timestep)
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
...
...
src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
View file @
4836cfad
...
@@ -140,6 +140,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -140,6 +140,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
# set all values
# set all values
self
.
set_timesteps
(
num_train_timesteps
,
None
,
num_train_timesteps
)
self
.
set_timesteps
(
num_train_timesteps
,
None
,
num_train_timesteps
)
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
...
@@ -295,6 +296,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -295,6 +296,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
_index_counter
=
defaultdict
(
int
)
self
.
_index_counter
=
defaultdict
(
int
)
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def
_sigma_to_t
(
self
,
sigma
,
log_sigmas
):
def
_sigma_to_t
(
self
,
sigma
,
log_sigmas
):
...
...
src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
View file @
4836cfad
...
@@ -140,6 +140,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -140,6 +140,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
set_timesteps
(
num_train_timesteps
,
None
,
num_train_timesteps
)
self
.
set_timesteps
(
num_train_timesteps
,
None
,
num_train_timesteps
)
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
...
@@ -284,6 +285,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -284,6 +285,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
_index_counter
=
defaultdict
(
int
)
self
.
_index_counter
=
defaultdict
(
int
)
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
@
property
@
property
def
state_in_first_order
(
self
):
def
state_in_first_order
(
self
):
...
...
src/diffusers/schedulers/scheduling_lms_discrete.py
View file @
4836cfad
...
@@ -168,6 +168,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -168,6 +168,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
is_scale_input_called
=
False
self
.
is_scale_input_called
=
False
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
@
property
@
property
def
init_noise_sigma
(
self
):
def
init_noise_sigma
(
self
):
...
@@ -279,6 +280,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -279,6 +280,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self
.
sigmas
=
torch
.
from_numpy
(
sigmas
).
to
(
device
=
device
)
self
.
sigmas
=
torch
.
from_numpy
(
sigmas
).
to
(
device
=
device
)
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
device
=
device
)
self
.
timesteps
=
torch
.
from_numpy
(
timesteps
).
to
(
device
=
device
)
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
self
.
derivatives
=
[]
self
.
derivatives
=
[]
...
...
src/diffusers/schedulers/scheduling_unipc_multistep.py
View file @
4836cfad
...
@@ -198,6 +198,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -198,6 +198,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self
.
solver_p
=
solver_p
self
.
solver_p
=
solver_p
self
.
last_sample
=
None
self
.
last_sample
=
None
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
@
property
@
property
def
step_index
(
self
):
def
step_index
(
self
):
...
@@ -268,6 +269,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -268,6 +269,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
# add an index counter for schedulers that allow duplicated timesteps
# add an index counter for schedulers that allow duplicated timesteps
self
.
_step_index
=
None
self
.
_step_index
=
None
self
.
sigmas
.
to
(
"cpu"
)
# to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
def
_threshold_sample
(
self
,
sample
:
torch
.
FloatTensor
)
->
torch
.
FloatTensor
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment