Commit 01cf7392, authored Jun 10, 2022 by Patrick von Platen
correct more
Parent: a14d774b
Showing 2 changed files with 11 additions and 20 deletions:
models/vision/ddim/modeling_ddim.py (+2, -14)
src/diffusers/schedulers/ddim.py (+9, -6)
models/vision/ddim/modeling_ddim.py

@@ -30,9 +30,6 @@ class DDIM(DiffusionPipeline):
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-        num_trained_timesteps = self.noise_scheduler.num_timesteps
-        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
         self.unet.to(torch_device)

         # Sample gaussian noise to begin loop
@@ -42,20 +39,11 @@ class DDIM(DiffusionPipeline):
             generator=generator,
         )
-        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
-        # Ideally, read DDIM paper in-detail understanding
-        # Notation (<variable name> -> <name in paper>
-        # - pred_noise_t -> e_theta(x_t, t)
-        # - pred_original_image -> f_theta(x_t, t) or x_0
-        # - std_dev_t -> sigma_t
-        # - eta -> η
-        # - pred_image_direction -> "direction pointingc to x_t"
-        # - pred_prev_image -> "x_t-1"
         for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
             # 1. predict noise residual
+            orig_t = self.noise_scheduler.get_orig_t(t, num_inference_steps)
             with torch.no_grad():
-                residual = self.unet(image, inference_step_times[t])
+                residual = self.unet(image, orig_t)

             # 2. predict previous mean of image x_t-1
             pred_prev_image = self.noise_scheduler.compute_prev_image_step(residual, image, t, num_inference_steps, eta)
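The pipeline-side change above drops the locally built inference_step_times range and instead asks the scheduler for the matching training timestep. A minimal standalone sketch (illustrative sizes, not part of the commit) of why indexing the evenly spaced range and the integer-division mapping used by get_orig_t agree:

    # Hypothetical check with made-up sizes: indexing an evenly spaced range
    # gives the same training timestep as the integer-division mapping.
    num_trained_timesteps = 1000
    num_inference_steps = 50

    inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

    for t in range(num_inference_steps):
        assert inference_step_times[t] == num_trained_timesteps // num_inference_steps * t

So moving the mapping into the scheduler changes where the timestep is computed, not which timestep the UNet sees.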
src/diffusers/schedulers/ddim.py

@@ -87,9 +87,14 @@ class DDIMScheduler(nn.Module, ConfigMixin):
             return torch.tensor(1.0)
         return self.alphas_cumprod[time_step]

+    def get_orig_t(self, t, num_inference_steps):
+        if t < 0:
+            return -1
+        return self.num_timesteps // num_inference_steps * t
+
     def get_variance(self, t, num_inference_steps):
-        orig_t = (self.num_timesteps // num_inference_steps) * t
+        orig_t = self.get_orig_t(t, num_inference_steps)
-        orig_prev_t = (self.num_timesteps // num_inference_steps) * (t - 1) if t > 0 else -1
+        orig_prev_t = self.get_orig_t(t - 1, num_inference_steps)
         alpha_prod_t = self.get_alpha_prod(orig_t)
         alpha_prod_t_prev = self.get_alpha_prod(orig_prev_t)
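For intuition, a small standalone sketch of the arithmetic the new get_orig_t helper implements; the sizes below are assumptions for illustration, not values from the commit:

    # Illustrative re-implementation of the new helper's arithmetic.
    num_timesteps = 1000       # assumed number of training timesteps
    num_inference_steps = 50   # assumed number of inference steps

    def get_orig_t(t):
        if t < 0:
            return -1
        return num_timesteps // num_inference_steps * t

    print(get_orig_t(10))   # 200 -> training timestep used at inference step 10
    print(get_orig_t(0))    # 0
    print(get_orig_t(-1))   # -1 -> "step before the chain started"

In get_variance, t = 0 therefore yields orig_prev_t = -1, which presumably takes the branch of get_alpha_prod that returns torch.tensor(1.0), i.e. a fully clean previous sample.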
@@ -113,10 +118,8 @@ class DDIMScheduler(nn.Module, ConfigMixin):
         # - pred_prev_image -> "x_t-1"

         # 1. get actual t and t-1
-        orig_t = (self.num_timesteps // num_inference_steps) * t
+        orig_t = self.get_orig_t(t, num_inference_steps)
-        orig_prev_t = (self.num_timesteps // num_inference_steps) * (t - 1) if t > 0 else -1
+        orig_prev_t = self.get_orig_t(t - 1, num_inference_steps)
-        # train_step = inference_step_times[t]
-        # prev_train_step = inference_step_times[t - 1] if t > 0 else -1

         # 2. compute alphas, betas
         alpha_prod_t = self.get_alpha_prod(orig_t)
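The notation comments (deleted from the pipeline here but mirrored in the scheduler) point to formulas (12) and (16) of the DDIM paper (https://arxiv.org/pdf/2010.02502.pdf). For reference, in the paper's notation, where \alpha_t is the cumulative product (alpha_prod_t in this code), they read approximately:

    x_{t-1} = \sqrt{\alpha_{t-1}} \left( \frac{x_t - \sqrt{1 - \alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\alpha_t}} \right)
            + \sqrt{1 - \alpha_{t-1} - \sigma_t^2}\;\epsilon_\theta(x_t, t) + \sigma_t\,\epsilon_t

    \sigma_t(\eta) = \eta\,\sqrt{\frac{1 - \alpha_{t-1}}{1 - \alpha_t}}\,\sqrt{1 - \frac{\alpha_t}{\alpha_{t-1}}}

Here eta = 0 gives the deterministic DDIM update, while eta = 1 recovers DDPM-style stochastic sampling; std_dev_t in the code corresponds to sigma_t above.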