Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
01cf7392
Commit
01cf7392
authored
Jun 10, 2022
by
Patrick von Platen
Browse files
correct more
parent
a14d774b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
20 deletions
+11
-20
models/vision/ddim/modeling_ddim.py
models/vision/ddim/modeling_ddim.py
+2
-14
src/diffusers/schedulers/ddim.py
src/diffusers/schedulers/ddim.py
+9
-6
No files found.
models/vision/ddim/modeling_ddim.py
View file @
01cf7392
...
...
@@ -30,9 +30,6 @@ class DDIM(DiffusionPipeline):
if
torch_device
is
None
:
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
num_trained_timesteps
=
self
.
noise_scheduler
.
num_timesteps
inference_step_times
=
range
(
0
,
num_trained_timesteps
,
num_trained_timesteps
//
num_inference_steps
)
self
.
unet
.
to
(
torch_device
)
# Sample gaussian noise to begin loop
...
...
@@ -42,20 +39,11 @@ class DDIM(DiffusionPipeline):
generator
=
generator
,
)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
# Notation (<variable name> -> <name in paper>
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointingc to x_t"
# - pred_prev_image -> "x_t-1"
for
t
in
tqdm
.
tqdm
(
reversed
(
range
(
num_inference_steps
)),
total
=
num_inference_steps
):
# 1. predict noise residual
orig_t
=
self
.
noise_scheduler
.
get_orig_t
(
t
,
num_inference_steps
)
with
torch
.
no_grad
():
residual
=
self
.
unet
(
image
,
inference_step_times
[
t
]
)
residual
=
self
.
unet
(
image
,
orig_t
)
# 2. predict previous mean of image x_t-1
pred_prev_image
=
self
.
noise_scheduler
.
compute_prev_image_step
(
residual
,
image
,
t
,
num_inference_steps
,
eta
)
...
...
src/diffusers/schedulers/ddim.py
View file @
01cf7392
...
...
@@ -87,9 +87,14 @@ class DDIMScheduler(nn.Module, ConfigMixin):
return
torch
.
tensor
(
1.0
)
return
self
.
alphas_cumprod
[
time_step
]
def
get_orig_t
(
self
,
t
,
num_inference_steps
):
if
t
<
0
:
return
-
1
return
self
.
num_timesteps
//
num_inference_steps
*
t
def
get_variance
(
self
,
t
,
num_inference_steps
):
orig_t
=
(
self
.
num_timesteps
//
num_inference_steps
)
*
t
orig_prev_t
=
(
self
.
num_timesteps
//
num_inference_steps
)
*
(
t
-
1
)
if
t
>
0
else
-
1
orig_t
=
self
.
get_orig_t
(
t
,
num_inference_steps
)
orig_prev_t
=
self
.
get_orig_t
(
t
-
1
,
num_inference_steps
)
alpha_prod_t
=
self
.
get_alpha_prod
(
orig_t
)
alpha_prod_t_prev
=
self
.
get_alpha_prod
(
orig_prev_t
)
...
...
@@ -113,10 +118,8 @@ class DDIMScheduler(nn.Module, ConfigMixin):
# - pred_prev_image -> "x_t-1"
# 1. get actual t and t-1
orig_t
=
(
self
.
num_timesteps
//
num_inference_steps
)
*
t
orig_prev_t
=
(
self
.
num_timesteps
//
num_inference_steps
)
*
(
t
-
1
)
if
t
>
0
else
-
1
# train_step = inference_step_times[t]
# prev_train_step = inference_step_times[t - 1] if t > 0 else -1
orig_t
=
self
.
get_orig_t
(
t
,
num_inference_steps
)
orig_prev_t
=
self
.
get_orig_t
(
t
-
1
,
num_inference_steps
)
# 2. compute alphas, betas
alpha_prod_t
=
self
.
get_alpha_prod
(
orig_t
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment