Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
3a5c65d5
Commit
3a5c65d5
authored
Jun 03, 2022
by
Patrick von Platen
Browse files
finish
parent
2032ad93
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
13 deletions
+17
-13
examples/sample_loop.py
examples/sample_loop.py
+17
-13
No files found.
examples/sample_loop.py
View file @
3a5c65d5
...
@@ -60,36 +60,40 @@ posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))
...
@@ -60,36 +60,40 @@ posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))
sqrt_recip_alphas_cumprod
=
torch
.
sqrt
(
1.0
/
alphas_cumprod
)
sqrt_recip_alphas_cumprod
=
torch
.
sqrt
(
1.0
/
alphas_cumprod
)
sqrt_recipm1_alphas_cumprod
=
torch
.
sqrt
(
1.0
/
alphas_cumprod
-
1
)
sqrt_recipm1_alphas_cumprod
=
torch
.
sqrt
(
1.0
/
alphas_cumprod
-
1
)
torch
.
manual_seed
(
0
)
# Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf
# 1: x_t ~ N(0,1)
x_t
=
dummy_noise
x_t
=
dummy_noise
# 2: for t = T, ...., 1 do
for
i
in
reversed
(
range
(
TIME_STEPS
)):
for
i
in
reversed
(
range
(
TIME_STEPS
)):
# t for x_t
t
=
torch
.
tensor
([
i
])
t
=
torch
.
tensor
([
i
])
torch
.
manual_seed
(
0
)
# 3: z ~ N(0, 1
)
noise
=
noise_like
(
x_t
.
shape
,
"cpu"
)
noise
=
noise_like
(
x_t
.
shape
,
"cpu"
)
x_t2
=
diffusion
.
p_sample
(
unet
,
x_t
,
t
,
noise
=
noise
)
# 4: √1αtxt − √1−αt1−α¯tθ(xt, t) + σtz
# ------------------------- MODEL ------------------------------------#
# ------------------------- MODEL ------------------------------------#
# predict epsilon
pred_noise
=
unet
(
x_t
,
t
)
# pred epsilon_theta
pred_noise
=
unet
(
x_t
,
t
)
pred_x
=
extract
(
sqrt_recip_alphas_cumprod
,
t
,
x_t
.
shape
)
*
x_t
-
extract
(
sqrt_recipm1_alphas_cumprod
,
t
,
x_t
.
shape
)
*
pred_noise
pred_x
=
extract
(
sqrt_recip_alphas_cumprod
,
t
,
x_t
.
shape
)
*
x_t
-
extract
(
sqrt_recipm1_alphas_cumprod
,
t
,
x_t
.
shape
)
*
pred_noise
pred_x
.
clamp_
(
-
1.0
,
1.0
)
pred_x
.
clamp_
(
-
1.0
,
1.0
)
# pred mean
posterior_mean
=
extract
(
posterior_mean_coef1
,
t
,
x_t
.
shape
)
*
pred_x
+
extract
(
posterior_mean_coef2
,
t
,
x_t
.
shape
)
*
x_t
posterior_mean
=
extract
(
posterior_mean_coef1
,
t
,
x_t
.
shape
)
*
pred_x
+
extract
(
posterior_mean_coef2
,
t
,
x_t
.
shape
)
*
x_t
# --------------------------------------------------------------------#
# --------------------------------------------------------------------#
# predict x_{t-1} (=pred_x)
# ------------------------- Variance Scheduler -----------------------#
# ------------------------- Variance Scheduler -----------------------#
# pred variance
posterior_log_variance
=
extract
(
posterior_log_variance_clipped
,
t
,
x_t
.
shape
)
posterior_log_variance
=
extract
(
posterior_log_variance_clipped
,
t
,
x_t
.
shape
)
# no noise when t == 0
b
,
*
_
,
device
=
*
x_t
.
shape
,
x_t
.
device
b
,
*
_
,
device
=
*
x_t
.
shape
,
x_t
.
device
nonzero_mask
=
(
1
-
(
t
==
0
).
float
()).
reshape
(
b
,
*
((
1
,)
*
(
len
(
x_t
.
shape
)
-
1
)))
nonzero_mask
=
(
1
-
(
t
==
0
).
float
()).
reshape
(
b
,
*
((
1
,)
*
(
len
(
x_t
.
shape
)
-
1
)))
posterior_variance
=
nonzero_mask
*
(
0.5
*
posterior_log_variance
).
exp
()
posterior_variance
=
nonzero_mask
*
(
0.5
*
posterior_log_variance
).
exp
()
# --------------------------------------------------------------------#
# --------------------------------------------------------------------#
x_t
=
posterior_mean
+
posterior_variance
*
noise
x_t_1
=
(
posterior_mean
+
posterior_variance
*
noise
).
to
(
torch
.
float32
)
x_t
=
x_t
.
to
(
torch
.
float32
)
# FOR PATRICK TO VERIFY: make sure manual loop is equal to function
# --------------------------------------------------------------------#
x_t_12
=
diffusion
.
p_sample
(
unet
,
x_t
,
t
,
noise
=
noise
)
assert
(
x_t_1
-
x_t_12
).
abs
().
sum
().
item
()
<
1e-3
# --------------------------------------------------------------------#
# make sure manual loop is equal to function
x_t
=
x_t_1
assert
(
x_t
-
x_t2
).
abs
().
sum
().
item
()
<
1e-3
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment