Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
7e11392d
Commit
7e11392d
authored
Jul 19, 2022
by
Patrick von Platen
Browse files
fix ddpm scheduler
parent
1f49a343
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
9 deletions
+12
-9
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+2
-8
src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm.py
+10
-1
No files found.
src/diffusers/pipelines/ddpm/pipeline_ddpm.py
View file @
7e11392d
...
@@ -51,13 +51,7 @@ class DDPMPipeline(DiffusionPipeline):
...
@@ -51,13 +51,7 @@ class DDPMPipeline(DiffusionPipeline):
# 2. predict previous mean of image x_t-1
# 2. predict previous mean of image x_t-1
pred_prev_image
=
self
.
scheduler
.
step
(
model_output
,
t
,
image
)[
"prev_sample"
]
pred_prev_image
=
self
.
scheduler
.
step
(
model_output
,
t
,
image
)[
"prev_sample"
]
# 3. optionally sample variance
# 3. set current image to prev_image: x_t -> x_t-1
variance
=
0
image
=
pred_prev_image
if
t
>
0
:
noise
=
torch
.
randn
(
image
.
shape
,
generator
=
generator
).
to
(
image
.
device
)
variance
=
self
.
scheduler
.
get_variance
(
t
).
sqrt
()
*
noise
# 4. set current image to prev_image: x_t -> x_t-1
image
=
pred_prev_image
+
variance
return
{
"sample"
:
image
}
return
{
"sample"
:
image
}
src/diffusers/schedulers/scheduling_ddpm.py
View file @
7e11392d
...
@@ -101,7 +101,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -101,7 +101,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
)[::
-
1
].
copy
()
)[::
-
1
].
copy
()
self
.
set_format
(
tensor_format
=
self
.
tensor_format
)
self
.
set_format
(
tensor_format
=
self
.
tensor_format
)
def
get_variance
(
self
,
t
,
variance_type
=
None
):
def
_
get_variance
(
self
,
t
,
variance_type
=
None
):
alpha_prod_t
=
self
.
alphas_cumprod
[
t
]
alpha_prod_t
=
self
.
alphas_cumprod
[
t
]
alpha_prod_t_prev
=
self
.
alphas_cumprod
[
t
-
1
]
if
t
>
0
else
self
.
one
alpha_prod_t_prev
=
self
.
alphas_cumprod
[
t
-
1
]
if
t
>
0
else
self
.
one
...
@@ -133,6 +133,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -133,6 +133,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
timestep
:
int
,
timestep
:
int
,
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
sample
:
Union
[
torch
.
FloatTensor
,
np
.
ndarray
],
predict_epsilon
=
True
,
predict_epsilon
=
True
,
generator
=
None
,
):
):
t
=
timestep
t
=
timestep
# 1. compute alphas, betas
# 1. compute alphas, betas
...
@@ -161,6 +162,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
...
@@ -161,6 +162,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_prev_sample
=
pred_original_sample_coeff
*
pred_original_sample
+
current_sample_coeff
*
sample
pred_prev_sample
=
pred_original_sample_coeff
*
pred_original_sample
+
current_sample_coeff
*
sample
# 6. Add noise
variance
=
0
if
t
>
0
:
noise
=
torch
.
randn
(
model_output
.
shape
,
generator
=
generator
).
to
(
model_output
.
device
)
variance
=
self
.
_get_variance
(
t
).
sqrt
()
*
noise
pred_prev_sample
=
pred_prev_sample
+
variance
return
{
"prev_sample"
:
pred_prev_sample
}
return
{
"prev_sample"
:
pred_prev_sample
}
def
add_noise
(
self
,
original_samples
,
noise
,
timesteps
):
def
add_noise
(
self
,
original_samples
,
noise
,
timesteps
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment