Commit 26b4319a (renzhc/diffusers_dcu)
Authored Apr 06, 2023 by William Berman
Committed by Will Berman, Apr 09, 2023

do not overwrite scheduler instance variables with type casted versions

parent 18ebd57b
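Before this commit, each scheduler's add_noise (and get_velocity) moved its lookup tables to the sample's device and dtype by assigning the cast result back onto self, e.g. self.alphas_cumprod = self.alphas_cumprod.to(...). A single half-precision call therefore permanently downcast the scheduler's internal state, and every later computation ran at the reduced precision. The fix casts into a local variable instead. A minimal sketch of the failure mode, using a hypothetical toy class rather than the real schedulers:

import torch

class ToyScheduler:
    # Hypothetical stand-in for a diffusers scheduler, for illustration only.
    def __init__(self):
        self.alphas_cumprod = torch.linspace(0.999, 0.001, 1000)  # float32 state

    def add_noise_old(self, samples, timesteps):
        # Old pattern: the cast copy is written back onto the instance,
        # so one fp16 call downcasts the scheduler's state for good.
        self.alphas_cumprod = self.alphas_cumprod.to(device=samples.device, dtype=samples.dtype)
        return self.alphas_cumprod[timesteps] * samples

    def add_noise_new(self, samples, timesteps):
        # New pattern: cast into a local; instance state is left untouched.
        alphas_cumprod = self.alphas_cumprod.to(device=samples.device, dtype=samples.dtype)
        return alphas_cumprod[timesteps] * samples

toy = ToyScheduler()
fp16_samples = torch.randn(4, dtype=torch.float16)
toy.add_noise_old(fp16_samples, torch.tensor([0, 1, 2, 3]))
print(toy.alphas_cumprod.dtype)  # torch.float16 -- state silently degraded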
Showing 13 changed files with 82 additions and 59 deletions.
src/diffusers/schedulers/scheduling_ddim.py                        +8  -6
src/diffusers/schedulers/scheduling_ddpm.py                        +6  -6
src/diffusers/schedulers/scheduling_deis_multistep.py              +4  -3
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py         +4  -3
src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py        +4  -3
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py    +5  -5
src/diffusers/schedulers/scheduling_euler_discrete.py              +4  -5
src/diffusers/schedulers/scheduling_heun_discrete.py              +11  -7
src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +13  -7
src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py           +13  -7
src/diffusers/schedulers/scheduling_lms_discrete.py                +1  -0
src/diffusers/schedulers/scheduling_pndm.py                        +5  -4
src/diffusers/schedulers/scheduling_unipc_multistep.py             +4  -3
src/diffusers/schedulers/scheduling_ddim.py

@@ -380,6 +380,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.FloatTensor,
@@ -387,15 +388,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)

-        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
         while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

-        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
         while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
@@ -403,19 +404,20 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples

+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
     def get_velocity(
         self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
-        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+        alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
         timesteps = timesteps.to(sample.device)

-        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
         while len(sqrt_alpha_prod.shape) < len(sample.shape):
             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

-        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
         while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
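With the DDIM change applied, mixed-precision callers no longer perturb the scheduler. A quick check along these lines (a sketch; assumes a diffusers build that includes this commit):

import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()
sample = torch.randn(1, 3, 8, 8, dtype=torch.float16)
noise = torch.randn_like(sample)
timesteps = torch.tensor([10])

noisy = scheduler.add_noise(sample, noise, timesteps)

assert noisy.dtype == torch.float16                     # computed in the caller's dtype
assert scheduler.alphas_cumprod.dtype == torch.float32  # instance state preserved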
src/diffusers/schedulers/scheduling_ddpm.py

@@ -380,15 +380,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)

-        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
         while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

-        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
         while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
@@ -400,15 +400,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
-        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+        alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
         timesteps = timesteps.to(sample.device)

-        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
         while len(sqrt_alpha_prod.shape) < len(sample.shape):
             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

-        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
         while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
src/diffusers/schedulers/scheduling_deis_multistep.py

@@ -477,6 +477,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         """
         return sample

+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.FloatTensor,
@@ -484,15 +485,15 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)

-        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
         while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

-        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
         while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

@@ -527,6 +527,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         """
         return sample

+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.FloatTensor,
@@ -534,15 +535,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)

-        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
         while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

-        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
         while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

@@ -602,6 +602,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         """
         return sample

+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.FloatTensor,
@@ -609,15 +610,15 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)

-        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
         while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

-        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
         while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py

@@ -279,6 +279,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
             prev_sample=prev_sample, pred_original_sample=pred_original_sample
         )

+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.FloatTensor,
@@ -286,19 +287,18 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         timesteps: torch.FloatTensor,
     ) -> torch.FloatTensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
-        self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
             # mps does not support float64
-            self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
             timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
         else:
-            self.timesteps = self.timesteps.to(original_samples.device)
+            schedule_timesteps = self.timesteps.to(original_samples.device)
             timesteps = timesteps.to(original_samples.device)

-        schedule_timesteps = self.timesteps
         step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

-        sigma = self.sigmas[step_indices].flatten()
+        sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
             sigma = sigma.unsqueeze(-1)
src/diffusers/schedulers/scheduling_euler_discrete.py

@@ -360,19 +360,18 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         timesteps: torch.FloatTensor,
     ) -> torch.FloatTensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
-        self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
             # mps does not support float64
-            self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
             timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
         else:
-            self.timesteps = self.timesteps.to(original_samples.device)
+            schedule_timesteps = self.timesteps.to(original_samples.device)
             timesteps = timesteps.to(original_samples.device)

-        schedule_timesteps = self.timesteps
         step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

-        sigma = self.sigmas[step_indices].flatten()
+        sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
             sigma = sigma.unsqueeze(-1)
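The sigma-based schedulers (Euler, Euler ancestral, and, via the copied add_noise, LMS) had the same problem twice over: self.sigmas was downcast to the sample dtype, and self.timesteps was rebound to a device-moved copy. Both now live in locals, with schedule_timesteps holding the device-matched copy of the schedule that the timestep-to-index lookup needs. The lookup itself, sketched here with invented tensor values (the variable names follow the diff):

import torch

schedule_timesteps = torch.tensor([999.0, 749.0, 499.0, 249.0, 0.0])  # scheduler's grid
sigmas = torch.tensor([14.6, 8.0, 3.5, 1.2, 0.0])                     # matching sigmas

timesteps = torch.tensor([499.0, 999.0])  # per-sample timesteps from the caller
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices]  # tensor([ 3.5000, 14.6000]) -- one sigma per sample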
src/diffusers/schedulers/scheduling_heun_discrete.py

@@ -112,8 +112,12 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         # set all values
         self.set_timesteps(num_train_timesteps, None, num_train_timesteps)

-    def index_for_timestep(self, timestep):
-        indices = (self.timesteps == timestep).nonzero()
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps
+
+        indices = (schedule_timesteps == timestep).nonzero()

         if self.state_in_first_order:
             pos = -1
         else:
@@ -277,18 +281,18 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         timesteps: torch.FloatTensor,
     ) -> torch.FloatTensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
-        self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
             # mps does not support float64
-            self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
             timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
         else:
-            self.timesteps = self.timesteps.to(original_samples.device)
+            schedule_timesteps = self.timesteps.to(original_samples.device)
             timesteps = timesteps.to(original_samples.device)

-        step_indices = [self.index_for_timestep(t) for t in timesteps]
+        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]

-        sigma = self.sigmas[step_indices].flatten()
+        sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
             sigma = sigma.unsqueeze(-1)
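Heun (and the two KDPM2 schedulers that copy it) route the lookup through index_for_timestep, so the device-matched schedule has to be passed down explicitly; the new optional parameter defaults to self.timesteps, which keeps existing one-argument callers working unchanged. The idiom in isolation, as a hypothetical miniature rather than the real scheduler:

import torch

class Toy:
    # Hypothetical miniature of the pattern; not the real HeunDiscreteScheduler.
    def __init__(self):
        self.timesteps = torch.tensor([3, 2, 1, 0])

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps  # old one-argument calls still work
        return (schedule_timesteps == timestep).nonzero().item()

toy = Toy()
moved = toy.timesteps.to("cpu")  # stand-in for the .to(device) copy made by add_noise
print(toy.index_for_timestep(2))         # 1, via instance state
print(toy.index_for_timestep(2, moved))  # 1, via the caller's copy -- no mutation needed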
src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py

@@ -114,8 +114,13 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         # set all values
         self.set_timesteps(num_train_timesteps, None, num_train_timesteps)

-    def index_for_timestep(self, timestep):
-        indices = (self.timesteps == timestep).nonzero()
+    # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps
+
+        indices = (schedule_timesteps == timestep).nonzero()

         if self.state_in_first_order:
             pos = -1
         else:
@@ -323,6 +328,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return SchedulerOutput(prev_sample=prev_sample)

+    # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.FloatTensor,
@@ -330,18 +336,18 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         timesteps: torch.FloatTensor,
     ) -> torch.FloatTensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
-        self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
             # mps does not support float64
-            self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
             timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
         else:
-            self.timesteps = self.timesteps.to(original_samples.device)
+            schedule_timesteps = self.timesteps.to(original_samples.device)
             timesteps = timesteps.to(original_samples.device)

-        step_indices = [self.index_for_timestep(t) for t in timesteps]
+        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]

-        sigma = self.sigmas[step_indices].flatten()
+        sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
             sigma = sigma.unsqueeze(-1)
src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py

@@ -113,8 +113,13 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         # set all values
         self.set_timesteps(num_train_timesteps, None, num_train_timesteps)

-    def index_for_timestep(self, timestep):
-        indices = (self.timesteps == timestep).nonzero()
+    # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps
+
+        indices = (schedule_timesteps == timestep).nonzero()

         if self.state_in_first_order:
             pos = -1
         else:
@@ -304,6 +309,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         return SchedulerOutput(prev_sample=prev_sample)

+    # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.FloatTensor,
@@ -311,18 +317,18 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         timesteps: torch.FloatTensor,
     ) -> torch.FloatTensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
-        self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
             # mps does not support float64
-            self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
             timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
         else:
-            self.timesteps = self.timesteps.to(original_samples.device)
+            schedule_timesteps = self.timesteps.to(original_samples.device)
             timesteps = timesteps.to(original_samples.device)

-        step_indices = [self.index_for_timestep(t) for t in timesteps]
+        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]

-        sigma = self.sigmas[step_indices].flatten()
+        sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
             sigma = sigma.unsqueeze(-1)
src/diffusers/schedulers/scheduling_lms_discrete.py

@@ -284,6 +284,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.FloatTensor,
src/diffusers/schedulers/scheduling_pndm.py

@@ -398,22 +398,23 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         return prev_sample

+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.FloatTensor,
         noise: torch.FloatTensor,
         timesteps: torch.IntTensor,
-    ) -> torch.Tensor:
+    ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)

-        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
         while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

-        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
         while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
src/diffusers/schedulers/scheduling_unipc_multistep.py

@@ -604,6 +604,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         """
         return sample

+    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
         self,
         original_samples: torch.FloatTensor,
@@ -611,15 +612,15 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
-        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)

-        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
         while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
             sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

-        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
         while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
             sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)