Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
33e5a831
Commit
33e5a831
authored
Jun 08, 2022
by
Patrick von Platen
Browse files
finish DDIM
parent
9fdbc14e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
14 deletions
+17
-14
models/vision/ddim/modeling_ddim.py
models/vision/ddim/modeling_ddim.py
+11
-13
src/diffusers/models/unet_ldm.py
src/diffusers/models/unet_ldm.py
+6
-1
No files found.
models/vision/ddim/modeling_ddim.py
View file @
33e5a831
...
@@ -19,12 +19,6 @@ import tqdm
...
@@ -19,12 +19,6 @@ import tqdm
import
torch
import
torch
def
compute_alpha
(
beta
,
t
):
beta
=
torch
.
cat
([
torch
.
zeros
(
1
).
to
(
beta
.
device
),
beta
],
dim
=
0
)
a
=
(
1
-
beta
).
cumprod
(
dim
=
0
).
index_select
(
0
,
t
+
1
).
view
(
-
1
,
1
,
1
,
1
)
return
a
class
DDIM
(
DiffusionPipeline
):
class
DDIM
(
DiffusionPipeline
):
def
__init__
(
self
,
unet
,
noise_scheduler
):
def
__init__
(
self
,
unet
,
noise_scheduler
):
...
@@ -32,7 +26,7 @@ class DDIM(DiffusionPipeline):
...
@@ -32,7 +26,7 @@ class DDIM(DiffusionPipeline):
self
.
register_modules
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
self
.
register_modules
(
unet
=
unet
,
noise_scheduler
=
noise_scheduler
)
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
eta
=
0.0
,
num_inference_steps
=
50
):
def
__call__
(
self
,
batch_size
=
1
,
generator
=
None
,
torch_device
=
None
,
eta
=
0.0
,
num_inference_steps
=
50
):
# eta
is η in paper
# eta
corresponds to η in paper and should be between [0, 1]
if
torch_device
is
None
:
if
torch_device
is
None
:
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
torch_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
...
@@ -59,15 +53,19 @@ class DDIM(DiffusionPipeline):
...
@@ -59,15 +53,19 @@ class DDIM(DiffusionPipeline):
coeff_1
=
(
alpha_prod_t_prev
-
alpha_prod_t
).
sqrt
()
*
alpha_prod_t_prev_rsqrt
*
beta_prod_t_prev_sqrt
/
beta_prod_t_sqrt
*
eta
coeff_1
=
(
alpha_prod_t_prev
-
alpha_prod_t
).
sqrt
()
*
alpha_prod_t_prev_rsqrt
*
beta_prod_t_prev_sqrt
/
beta_prod_t_sqrt
*
eta
coeff_2
=
((
1
-
alpha_prod_t_prev
)
-
coeff_1
**
2
).
sqrt
()
coeff_2
=
((
1
-
alpha_prod_t_prev
)
-
coeff_1
**
2
).
sqrt
()
# model forward
with
torch
.
no_grad
():
with
torch
.
no_grad
():
noise_residual
=
self
.
unet
(
image
,
train_step
)
noise_residual
=
self
.
unet
(
image
,
train_step
)
print
(
train_step
)
# predict mean of prev image
pred_mean
=
alpha_prod_t_rsqrt
*
(
image
-
beta_prod_t_sqrt
*
noise_residual
)
pred_mean
=
(
1
/
alpha_prod_t_prev_rsqrt
)
*
pred_mean
+
coeff_2
*
noise_residual
pred_mean
=
(
image
-
noise_residual
*
beta_prod_t_sqrt
)
*
alpha_prod_t_rsqrt
# if eta > 0.0 add noise. Note eta = 1.0 essentially corresponds to DDPM
xt_next
=
alpha_prod_t_prev
.
sqrt
()
*
pred_mean
+
coeff_1
*
torch
.
randn_like
(
image
)
+
coeff_2
*
noise_residual
if
eta
>
0.0
:
# xt_next = 1 / alpha_prod_t_rsqrt * pred_mean + coeff_1 * torch.randn_like(image) + coeff_2 * noise_residual
noise
=
self
.
noise_scheduler
.
sample_noise
(
image
.
shape
,
device
=
image
.
device
,
generator
=
generator
)
# eta
image
=
pred_mean
+
coeff_1
*
noise
image
=
xt_next
else
:
image
=
pred_mean
return
image
return
image
src/diffusers/models/unet_ldm.py
View file @
33e5a831
...
@@ -6,7 +6,12 @@ import numpy as np
...
@@ -6,7 +6,12 @@ import numpy as np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
einops
import
repeat
,
rearrange
try
:
from
einops
import
repeat
,
rearrange
except
:
print
(
"Einops is not installed"
)
pass
from
..configuration_utils
import
ConfigMixin
from
..configuration_utils
import
ConfigMixin
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment