Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
26b4319a
"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "354d35adb02e943d79014e5713290a4551d3dd01"
Commit
26b4319a
authored
Apr 06, 2023
by
William Berman
Committed by
Will Berman
Apr 09, 2023
Browse files
do not overwrite scheduler instance variables with type casted versions
parent
18ebd57b
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
82 additions
and
59 deletions
+82
-59
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+8
-6
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+6
-6
src/diffusers/schedulers/scheduling_deis_multistep.py
src/diffusers/schedulers/scheduling_deis_multistep.py
+4
-3
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+4
-3
src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+4
-3
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
...ffusers/schedulers/scheduling_euler_ancestral_discrete.py
+5
-5
src/diffusers/schedulers/scheduling_euler_discrete.py
src/diffusers/schedulers/scheduling_euler_discrete.py
+4
-5
src/diffusers/schedulers/scheduling_heun_discrete.py
src/diffusers/schedulers/scheduling_heun_discrete.py
+11
-7
src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
...users/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
+13
-7
src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
+13
-7
src/diffusers/schedulers/scheduling_lms_discrete.py
src/diffusers/schedulers/scheduling_lms_discrete.py
+1
-0
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+5
-4
src/diffusers/schedulers/scheduling_unipc_multistep.py
src/diffusers/schedulers/scheduling_unipc_multistep.py
+4
-3
No files found.
src/diffusers/schedulers/scheduling_ddim.py
View file @
26b4319a
...
@@ -380,6 +380,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -380,6 +380,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return
DDIMSchedulerOutput
(
prev_sample
=
prev_sample
,
pred_original_sample
=
pred_original_sample
)
return
DDIMSchedulerOutput
(
prev_sample
=
prev_sample
,
pred_original_sample
=
pred_original_sample
)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def
add_noise
(
def
add_noise
(
self
,
self
,
original_samples
:
torch
.
FloatTensor
,
original_samples
:
torch
.
FloatTensor
,
...
@@ -387,15 +388,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -387,15 +388,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
timesteps
:
torch
.
IntTensor
,
timesteps
:
torch
.
IntTensor
,
)
->
torch
.
FloatTensor
:
)
->
torch
.
FloatTensor
:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
self
.
alphas_cumprod
=
self
.
alphas_cumprod
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
alphas_cumprod
=
self
.
alphas_cumprod
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
while
len
(
sqrt_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sqrt_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
sqrt_alpha_prod
=
sqrt_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_alpha_prod
=
sqrt_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
(
1
-
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
flatten
()
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
flatten
()
while
len
(
sqrt_one_minus_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sqrt_one_minus_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
unsqueeze
(
-
1
)
...
@@ -403,19 +404,20 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -403,19 +404,20 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
noisy_samples
=
sqrt_alpha_prod
*
original_samples
+
sqrt_one_minus_alpha_prod
*
noise
noisy_samples
=
sqrt_alpha_prod
*
original_samples
+
sqrt_one_minus_alpha_prod
*
noise
return
noisy_samples
return
noisy_samples
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def
get_velocity
(
def
get_velocity
(
self
,
sample
:
torch
.
FloatTensor
,
noise
:
torch
.
FloatTensor
,
timesteps
:
torch
.
IntTensor
self
,
sample
:
torch
.
FloatTensor
,
noise
:
torch
.
FloatTensor
,
timesteps
:
torch
.
IntTensor
)
->
torch
.
FloatTensor
:
)
->
torch
.
FloatTensor
:
# Make sure alphas_cumprod and timestep have same device and dtype as sample
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self
.
alphas_cumprod
=
self
.
alphas_cumprod
.
to
(
device
=
sample
.
device
,
dtype
=
sample
.
dtype
)
alphas_cumprod
=
self
.
alphas_cumprod
.
to
(
device
=
sample
.
device
,
dtype
=
sample
.
dtype
)
timesteps
=
timesteps
.
to
(
sample
.
device
)
timesteps
=
timesteps
.
to
(
sample
.
device
)
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
while
len
(
sqrt_alpha_prod
.
shape
)
<
len
(
sample
.
shape
):
while
len
(
sqrt_alpha_prod
.
shape
)
<
len
(
sample
.
shape
):
sqrt_alpha_prod
=
sqrt_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_alpha_prod
=
sqrt_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
(
1
-
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
flatten
()
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
flatten
()
while
len
(
sqrt_one_minus_alpha_prod
.
shape
)
<
len
(
sample
.
shape
):
while
len
(
sqrt_one_minus_alpha_prod
.
shape
)
<
len
(
sample
.
shape
):
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
unsqueeze
(
-
1
)
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
26b4319a
...
@@ -380,15 +380,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -380,15 +380,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
timesteps
:
torch
.
IntTensor
,
timesteps
:
torch
.
IntTensor
,
)
->
torch
.
FloatTensor
:
)
->
torch
.
FloatTensor
:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
self
.
alphas_cumprod
=
self
.
alphas_cumprod
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
alphas_cumprod
=
self
.
alphas_cumprod
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
while
len
(
sqrt_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sqrt_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
sqrt_alpha_prod
=
sqrt_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_alpha_prod
=
sqrt_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
(
1
-
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
flatten
()
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
flatten
()
while
len
(
sqrt_one_minus_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sqrt_one_minus_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
unsqueeze
(
-
1
)
...
@@ -400,15 +400,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -400,15 +400,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
self
,
sample
:
torch
.
FloatTensor
,
noise
:
torch
.
FloatTensor
,
timesteps
:
torch
.
IntTensor
self
,
sample
:
torch
.
FloatTensor
,
noise
:
torch
.
FloatTensor
,
timesteps
:
torch
.
IntTensor
)
->
torch
.
FloatTensor
:
)
->
torch
.
FloatTensor
:
# Make sure alphas_cumprod and timestep have same device and dtype as sample
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self
.
alphas_cumprod
=
self
.
alphas_cumprod
.
to
(
device
=
sample
.
device
,
dtype
=
sample
.
dtype
)
alphas_cumprod
=
self
.
alphas_cumprod
.
to
(
device
=
sample
.
device
,
dtype
=
sample
.
dtype
)
timesteps
=
timesteps
.
to
(
sample
.
device
)
timesteps
=
timesteps
.
to
(
sample
.
device
)
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
while
len
(
sqrt_alpha_prod
.
shape
)
<
len
(
sample
.
shape
):
while
len
(
sqrt_alpha_prod
.
shape
)
<
len
(
sample
.
shape
):
sqrt_alpha_prod
=
sqrt_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_alpha_prod
=
sqrt_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
(
1
-
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
flatten
()
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
flatten
()
while
len
(
sqrt_one_minus_alpha_prod
.
shape
)
<
len
(
sample
.
shape
):
while
len
(
sqrt_one_minus_alpha_prod
.
shape
)
<
len
(
sample
.
shape
):
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
unsqueeze
(
-
1
)
...
...
src/diffusers/schedulers/scheduling_deis_multistep.py
View file @
26b4319a
...
@@ -477,6 +477,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -477,6 +477,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
"""
return
sample
return
sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def
add_noise
(
def
add_noise
(
self
,
self
,
original_samples
:
torch
.
FloatTensor
,
original_samples
:
torch
.
FloatTensor
,
...
@@ -484,15 +485,15 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -484,15 +485,15 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
timesteps
:
torch
.
IntTensor
,
timesteps
:
torch
.
IntTensor
,
)
->
torch
.
FloatTensor
:
)
->
torch
.
FloatTensor
:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
self
.
alphas_cumprod
=
self
.
alphas_cumprod
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
alphas_cumprod
=
self
.
alphas_cumprod
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
while
len
(
sqrt_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sqrt_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
sqrt_alpha_prod
=
sqrt_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_alpha_prod
=
sqrt_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
(
1
-
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
flatten
()
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
flatten
()
while
len
(
sqrt_one_minus_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sqrt_one_minus_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
unsqueeze
(
-
1
)
...
...
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
View file @
26b4319a
...
@@ -527,6 +527,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -527,6 +527,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
"""
return
sample
return
sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def
add_noise
(
def
add_noise
(
self
,
self
,
original_samples
:
torch
.
FloatTensor
,
original_samples
:
torch
.
FloatTensor
,
...
@@ -534,15 +535,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -534,15 +535,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
timesteps
:
torch
.
IntTensor
,
timesteps
:
torch
.
IntTensor
,
)
->
torch
.
FloatTensor
:
)
->
torch
.
FloatTensor
:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
self
.
alphas_cumprod
=
self
.
alphas_cumprod
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
alphas_cumprod
=
self
.
alphas_cumprod
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
while
len
(
sqrt_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sqrt_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
sqrt_alpha_prod
=
sqrt_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_alpha_prod
=
sqrt_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
(
1
-
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
flatten
()
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
flatten
()
while
len
(
sqrt_one_minus_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sqrt_one_minus_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
unsqueeze
(
-
1
)
...
...
src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
View file @
26b4319a
...
@@ -602,6 +602,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -602,6 +602,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
"""
"""
return
sample
return
sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def
add_noise
(
def
add_noise
(
self
,
self
,
original_samples
:
torch
.
FloatTensor
,
original_samples
:
torch
.
FloatTensor
,
...
@@ -609,15 +610,15 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -609,15 +610,15 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
timesteps
:
torch
.
IntTensor
,
timesteps
:
torch
.
IntTensor
,
)
->
torch
.
FloatTensor
:
)
->
torch
.
FloatTensor
:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
self
.
alphas_cumprod
=
self
.
alphas_cumprod
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
alphas_cumprod
=
self
.
alphas_cumprod
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
while
len
(
sqrt_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sqrt_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
sqrt_alpha_prod
=
sqrt_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_alpha_prod
=
sqrt_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
(
1
-
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
flatten
()
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
flatten
()
while
len
(
sqrt_one_minus_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sqrt_one_minus_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
unsqueeze
(
-
1
)
...
...
src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
View file @
26b4319a
...
@@ -279,6 +279,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -279,6 +279,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
prev_sample
=
prev_sample
,
pred_original_sample
=
pred_original_sample
prev_sample
=
prev_sample
,
pred_original_sample
=
pred_original_sample
)
)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def
add_noise
(
def
add_noise
(
self
,
self
,
original_samples
:
torch
.
FloatTensor
,
original_samples
:
torch
.
FloatTensor
,
...
@@ -286,19 +287,18 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -286,19 +287,18 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
timesteps
:
torch
.
FloatTensor
,
timesteps
:
torch
.
FloatTensor
,
)
->
torch
.
FloatTensor
:
)
->
torch
.
FloatTensor
:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
# Make sure sigmas and timesteps have the same device and dtype as original_samples
self
.
sigmas
=
self
.
sigmas
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
sigmas
=
self
.
sigmas
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
if
original_samples
.
device
.
type
==
"mps"
and
torch
.
is_floating_point
(
timesteps
):
if
original_samples
.
device
.
type
==
"mps"
and
torch
.
is_floating_point
(
timesteps
):
# mps does not support float64
# mps does not support float64
s
elf
.
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
s
chedule_
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
else
:
else
:
s
elf
.
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
s
chedule_
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
schedule_timesteps
=
self
.
timesteps
step_indices
=
[(
schedule_timesteps
==
t
).
nonzero
().
item
()
for
t
in
timesteps
]
step_indices
=
[(
schedule_timesteps
==
t
).
nonzero
().
item
()
for
t
in
timesteps
]
sigma
=
self
.
sigmas
[
step_indices
].
flatten
()
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
sigma
=
sigma
.
unsqueeze
(
-
1
)
sigma
=
sigma
.
unsqueeze
(
-
1
)
...
...
src/diffusers/schedulers/scheduling_euler_discrete.py
View file @
26b4319a
...
@@ -360,19 +360,18 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -360,19 +360,18 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
timesteps
:
torch
.
FloatTensor
,
timesteps
:
torch
.
FloatTensor
,
)
->
torch
.
FloatTensor
:
)
->
torch
.
FloatTensor
:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
# Make sure sigmas and timesteps have the same device and dtype as original_samples
self
.
sigmas
=
self
.
sigmas
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
sigmas
=
self
.
sigmas
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
if
original_samples
.
device
.
type
==
"mps"
and
torch
.
is_floating_point
(
timesteps
):
if
original_samples
.
device
.
type
==
"mps"
and
torch
.
is_floating_point
(
timesteps
):
# mps does not support float64
# mps does not support float64
s
elf
.
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
s
chedule_
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
else
:
else
:
s
elf
.
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
s
chedule_
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
schedule_timesteps
=
self
.
timesteps
step_indices
=
[(
schedule_timesteps
==
t
).
nonzero
().
item
()
for
t
in
timesteps
]
step_indices
=
[(
schedule_timesteps
==
t
).
nonzero
().
item
()
for
t
in
timesteps
]
sigma
=
self
.
sigmas
[
step_indices
].
flatten
()
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
sigma
=
sigma
.
unsqueeze
(
-
1
)
sigma
=
sigma
.
unsqueeze
(
-
1
)
...
...
src/diffusers/schedulers/scheduling_heun_discrete.py
View file @
26b4319a
...
@@ -112,8 +112,12 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -112,8 +112,12 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
# set all values
# set all values
self
.
set_timesteps
(
num_train_timesteps
,
None
,
num_train_timesteps
)
self
.
set_timesteps
(
num_train_timesteps
,
None
,
num_train_timesteps
)
def
index_for_timestep
(
self
,
timestep
):
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
indices
=
(
self
.
timesteps
==
timestep
).
nonzero
()
if
schedule_timesteps
is
None
:
schedule_timesteps
=
self
.
timesteps
indices
=
(
schedule_timesteps
==
timestep
).
nonzero
()
if
self
.
state_in_first_order
:
if
self
.
state_in_first_order
:
pos
=
-
1
pos
=
-
1
else
:
else
:
...
@@ -277,18 +281,18 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -277,18 +281,18 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
timesteps
:
torch
.
FloatTensor
,
timesteps
:
torch
.
FloatTensor
,
)
->
torch
.
FloatTensor
:
)
->
torch
.
FloatTensor
:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
# Make sure sigmas and timesteps have the same device and dtype as original_samples
self
.
sigmas
=
self
.
sigmas
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
sigmas
=
self
.
sigmas
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
if
original_samples
.
device
.
type
==
"mps"
and
torch
.
is_floating_point
(
timesteps
):
if
original_samples
.
device
.
type
==
"mps"
and
torch
.
is_floating_point
(
timesteps
):
# mps does not support float64
# mps does not support float64
s
elf
.
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
s
chedule_
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
else
:
else
:
s
elf
.
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
s
chedule_
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
step_indices
=
[
self
.
index_for_timestep
(
t
)
for
t
in
timesteps
]
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
sigma
=
self
.
sigmas
[
step_indices
].
flatten
()
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
sigma
=
sigma
.
unsqueeze
(
-
1
)
sigma
=
sigma
.
unsqueeze
(
-
1
)
...
...
src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
View file @
26b4319a
...
@@ -114,8 +114,13 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -114,8 +114,13 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
# set all values
# set all values
self
.
set_timesteps
(
num_train_timesteps
,
None
,
num_train_timesteps
)
self
.
set_timesteps
(
num_train_timesteps
,
None
,
num_train_timesteps
)
def
index_for_timestep
(
self
,
timestep
):
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
indices
=
(
self
.
timesteps
==
timestep
).
nonzero
()
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
schedule_timesteps
is
None
:
schedule_timesteps
=
self
.
timesteps
indices
=
(
schedule_timesteps
==
timestep
).
nonzero
()
if
self
.
state_in_first_order
:
if
self
.
state_in_first_order
:
pos
=
-
1
pos
=
-
1
else
:
else
:
...
@@ -323,6 +328,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -323,6 +328,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
return
SchedulerOutput
(
prev_sample
=
prev_sample
)
return
SchedulerOutput
(
prev_sample
=
prev_sample
)
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
def
add_noise
(
def
add_noise
(
self
,
self
,
original_samples
:
torch
.
FloatTensor
,
original_samples
:
torch
.
FloatTensor
,
...
@@ -330,18 +336,18 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -330,18 +336,18 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
timesteps
:
torch
.
FloatTensor
,
timesteps
:
torch
.
FloatTensor
,
)
->
torch
.
FloatTensor
:
)
->
torch
.
FloatTensor
:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
# Make sure sigmas and timesteps have the same device and dtype as original_samples
self
.
sigmas
=
self
.
sigmas
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
sigmas
=
self
.
sigmas
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
if
original_samples
.
device
.
type
==
"mps"
and
torch
.
is_floating_point
(
timesteps
):
if
original_samples
.
device
.
type
==
"mps"
and
torch
.
is_floating_point
(
timesteps
):
# mps does not support float64
# mps does not support float64
s
elf
.
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
s
chedule_
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
else
:
else
:
s
elf
.
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
s
chedule_
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
step_indices
=
[
self
.
index_for_timestep
(
t
)
for
t
in
timesteps
]
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
sigma
=
self
.
sigmas
[
step_indices
].
flatten
()
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
sigma
=
sigma
.
unsqueeze
(
-
1
)
sigma
=
sigma
.
unsqueeze
(
-
1
)
...
...
src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
View file @
26b4319a
...
@@ -113,8 +113,13 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -113,8 +113,13 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
# set all values
# set all values
self
.
set_timesteps
(
num_train_timesteps
,
None
,
num_train_timesteps
)
self
.
set_timesteps
(
num_train_timesteps
,
None
,
num_train_timesteps
)
def
index_for_timestep
(
self
,
timestep
):
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
indices
=
(
self
.
timesteps
==
timestep
).
nonzero
()
def
index_for_timestep
(
self
,
timestep
,
schedule_timesteps
=
None
):
if
schedule_timesteps
is
None
:
schedule_timesteps
=
self
.
timesteps
indices
=
(
schedule_timesteps
==
timestep
).
nonzero
()
if
self
.
state_in_first_order
:
if
self
.
state_in_first_order
:
pos
=
-
1
pos
=
-
1
else
:
else
:
...
@@ -304,6 +309,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -304,6 +309,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
return
SchedulerOutput
(
prev_sample
=
prev_sample
)
return
SchedulerOutput
(
prev_sample
=
prev_sample
)
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
def
add_noise
(
def
add_noise
(
self
,
self
,
original_samples
:
torch
.
FloatTensor
,
original_samples
:
torch
.
FloatTensor
,
...
@@ -311,18 +317,18 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -311,18 +317,18 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
timesteps
:
torch
.
FloatTensor
,
timesteps
:
torch
.
FloatTensor
,
)
->
torch
.
FloatTensor
:
)
->
torch
.
FloatTensor
:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
# Make sure sigmas and timesteps have the same device and dtype as original_samples
self
.
sigmas
=
self
.
sigmas
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
sigmas
=
self
.
sigmas
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
if
original_samples
.
device
.
type
==
"mps"
and
torch
.
is_floating_point
(
timesteps
):
if
original_samples
.
device
.
type
==
"mps"
and
torch
.
is_floating_point
(
timesteps
):
# mps does not support float64
# mps does not support float64
s
elf
.
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
s
chedule_
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
,
dtype
=
torch
.
float32
)
else
:
else
:
s
elf
.
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
s
chedule_
timesteps
=
self
.
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
step_indices
=
[
self
.
index_for_timestep
(
t
)
for
t
in
timesteps
]
step_indices
=
[
self
.
index_for_timestep
(
t
,
schedule_timesteps
)
for
t
in
timesteps
]
sigma
=
self
.
sigmas
[
step_indices
].
flatten
()
sigma
=
sigmas
[
step_indices
].
flatten
()
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sigma
.
shape
)
<
len
(
original_samples
.
shape
):
sigma
=
sigma
.
unsqueeze
(
-
1
)
sigma
=
sigma
.
unsqueeze
(
-
1
)
...
...
src/diffusers/schedulers/scheduling_lms_discrete.py
View file @
26b4319a
...
@@ -284,6 +284,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
...
@@ -284,6 +284,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
return
LMSDiscreteSchedulerOutput
(
prev_sample
=
prev_sample
,
pred_original_sample
=
pred_original_sample
)
return
LMSDiscreteSchedulerOutput
(
prev_sample
=
prev_sample
,
pred_original_sample
=
pred_original_sample
)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def
add_noise
(
def
add_noise
(
self
,
self
,
original_samples
:
torch
.
FloatTensor
,
original_samples
:
torch
.
FloatTensor
,
...
...
src/diffusers/schedulers/scheduling_pndm.py
View file @
26b4319a
...
@@ -398,22 +398,23 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -398,22 +398,23 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return
prev_sample
return
prev_sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def
add_noise
(
def
add_noise
(
self
,
self
,
original_samples
:
torch
.
FloatTensor
,
original_samples
:
torch
.
FloatTensor
,
noise
:
torch
.
FloatTensor
,
noise
:
torch
.
FloatTensor
,
timesteps
:
torch
.
IntTensor
,
timesteps
:
torch
.
IntTensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Float
Tensor
:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
self
.
alphas_cumprod
=
self
.
alphas_cumprod
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
alphas_cumprod
=
self
.
alphas_cumprod
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
while
len
(
sqrt_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sqrt_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
sqrt_alpha_prod
=
sqrt_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_alpha_prod
=
sqrt_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
(
1
-
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
flatten
()
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
flatten
()
while
len
(
sqrt_one_minus_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sqrt_one_minus_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
unsqueeze
(
-
1
)
...
...
src/diffusers/schedulers/scheduling_unipc_multistep.py
View file @
26b4319a
...
@@ -604,6 +604,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -604,6 +604,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
"""
return
sample
return
sample
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def
add_noise
(
def
add_noise
(
self
,
self
,
original_samples
:
torch
.
FloatTensor
,
original_samples
:
torch
.
FloatTensor
,
...
@@ -611,15 +612,15 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
...
@@ -611,15 +612,15 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
timesteps
:
torch
.
IntTensor
,
timesteps
:
torch
.
IntTensor
,
)
->
torch
.
FloatTensor
:
)
->
torch
.
FloatTensor
:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
self
.
alphas_cumprod
=
self
.
alphas_cumprod
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
alphas_cumprod
=
self
.
alphas_cumprod
.
to
(
device
=
original_samples
.
device
,
dtype
=
original_samples
.
dtype
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
while
len
(
sqrt_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sqrt_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
sqrt_alpha_prod
=
sqrt_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_alpha_prod
=
sqrt_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
(
1
-
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
flatten
()
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
flatten
()
while
len
(
sqrt_one_minus_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sqrt_one_minus_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
unsqueeze
(
-
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment