Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
33045382
Unverified
Commit
33045382
authored
Sep 27, 2022
by
Suraj Patil
Committed by
GitHub
Sep 27, 2022
Browse files
[DDIM, DDPM] fix add_noise (#648)
fix add noise
parent
e5eed523
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
5 deletions
+12
-5
src/diffusers/schedulers/scheduling_ddim.py
src/diffusers/schedulers/scheduling_ddim.py
+6
-1
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+6
-2
src/diffusers/schedulers/scheduling_pndm.py
src/diffusers/schedulers/scheduling_pndm.py
+0
-2
No files found.
src/diffusers/schedulers/scheduling_ddim.py
View file @
33045382
...
@@ -282,7 +282,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -282,7 +282,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
noise
:
torch
.
FloatTensor
,
noise
:
torch
.
FloatTensor
,
timesteps
:
torch
.
IntTensor
,
timesteps
:
torch
.
IntTensor
,
)
->
torch
.
FloatTensor
:
)
->
torch
.
FloatTensor
:
timesteps
=
timesteps
.
to
(
self
.
alphas_cumprod
.
device
)
if
self
.
alphas_cumprod
.
device
!=
original_samples
.
device
:
self
.
alphas_cumprod
=
self
.
alphas_cumprod
.
to
(
original_samples
.
device
)
if
timesteps
.
device
!=
original_samples
.
device
:
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
while
len
(
sqrt_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sqrt_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
...
...
src/diffusers/schedulers/scheduling_ddpm.py
View file @
33045382
...
@@ -268,7 +268,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -268,7 +268,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
noise
:
torch
.
FloatTensor
,
noise
:
torch
.
FloatTensor
,
timesteps
:
torch
.
IntTensor
,
timesteps
:
torch
.
IntTensor
,
)
->
torch
.
FloatTensor
:
)
->
torch
.
FloatTensor
:
timesteps
=
timesteps
.
to
(
self
.
alphas_cumprod
.
device
)
if
self
.
alphas_cumprod
.
device
!=
original_samples
.
device
:
self
.
alphas_cumprod
=
self
.
alphas_cumprod
.
to
(
original_samples
.
device
)
if
timesteps
.
device
!=
original_samples
.
device
:
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
...
@@ -276,7 +280,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -276,7 +280,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
sqrt_alpha_prod
=
sqrt_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_alpha_prod
=
sqrt_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
=
(
1
-
self
.
alphas_cumprod
[
timesteps
])
**
0.5
sqrt_one_minus_alpha_prod
.
flatten
()
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
flatten
()
while
len
(
sqrt_one_minus_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sqrt_one_minus_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
unsqueeze
(
-
1
)
sqrt_one_minus_alpha_prod
=
sqrt_one_minus_alpha_prod
.
unsqueeze
(
-
1
)
...
...
src/diffusers/schedulers/scheduling_pndm.py
View file @
33045382
...
@@ -387,8 +387,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -387,8 +387,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
if
timesteps
.
device
!=
original_samples
.
device
:
if
timesteps
.
device
!=
original_samples
.
device
:
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
original_samples
.
device
)
timesteps
=
timesteps
.
to
(
self
.
alphas_cumprod
.
device
)
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
self
.
alphas_cumprod
[
timesteps
]
**
0.5
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
sqrt_alpha_prod
=
sqrt_alpha_prod
.
flatten
()
while
len
(
sqrt_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
while
len
(
sqrt_alpha_prod
.
shape
)
<
len
(
original_samples
.
shape
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment