Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
3a5c65d5
"1" did not exist on "25feac9e65ff7a7ca87d75150555bc010f3dfdd0"
Commit
3a5c65d5
authored
Jun 03, 2022
by
Patrick von Platen
Browse files
finish
parent
2032ad93
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
13 deletions
+17
-13
examples/sample_loop.py
examples/sample_loop.py
+17
-13
No files found.
examples/sample_loop.py
View file @
3a5c65d5
...
@@ -60,36 +60,40 @@ posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))
...
@@ -60,36 +60,40 @@ posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))
sqrt_recip_alphas_cumprod
=
torch
.
sqrt
(
1.0
/
alphas_cumprod
)
sqrt_recip_alphas_cumprod
=
torch
.
sqrt
(
1.0
/
alphas_cumprod
)
sqrt_recipm1_alphas_cumprod
=
torch
.
sqrt
(
1.0
/
alphas_cumprod
-
1
)
sqrt_recipm1_alphas_cumprod
=
torch
.
sqrt
(
1.0
/
alphas_cumprod
-
1
)
torch
.
manual_seed
(
0
)
# Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf
# 1: x_t ~ N(0,1)
x_t
=
dummy_noise
x_t
=
dummy_noise
# 2: for t = T, ...., 1 do
for
i
in
reversed
(
range
(
TIME_STEPS
)):
for
i
in
reversed
(
range
(
TIME_STEPS
)):
# t for x_t
t
=
torch
.
tensor
([
i
])
t
=
torch
.
tensor
([
i
])
torch
.
manual_seed
(
0
)
# 3: z ~ N(0, 1
)
noise
=
noise_like
(
x_t
.
shape
,
"cpu"
)
noise
=
noise_like
(
x_t
.
shape
,
"cpu"
)
x_t2
=
diffusion
.
p_sample
(
unet
,
x_t
,
t
,
noise
=
noise
)
# 4: √1αtxt − √1−αt1−α¯tθ(xt, t) + σtz
# ------------------------- MODEL ------------------------------------#
# ------------------------- MODEL ------------------------------------#
# predict epsilon
pred_noise
=
unet
(
x_t
,
t
)
# pred epsilon_theta
pred_noise
=
unet
(
x_t
,
t
)
pred_x
=
extract
(
sqrt_recip_alphas_cumprod
,
t
,
x_t
.
shape
)
*
x_t
-
extract
(
sqrt_recipm1_alphas_cumprod
,
t
,
x_t
.
shape
)
*
pred_noise
pred_x
=
extract
(
sqrt_recip_alphas_cumprod
,
t
,
x_t
.
shape
)
*
x_t
-
extract
(
sqrt_recipm1_alphas_cumprod
,
t
,
x_t
.
shape
)
*
pred_noise
pred_x
.
clamp_
(
-
1.0
,
1.0
)
pred_x
.
clamp_
(
-
1.0
,
1.0
)
# pred mean
posterior_mean
=
extract
(
posterior_mean_coef1
,
t
,
x_t
.
shape
)
*
pred_x
+
extract
(
posterior_mean_coef2
,
t
,
x_t
.
shape
)
*
x_t
posterior_mean
=
extract
(
posterior_mean_coef1
,
t
,
x_t
.
shape
)
*
pred_x
+
extract
(
posterior_mean_coef2
,
t
,
x_t
.
shape
)
*
x_t
# --------------------------------------------------------------------#
# --------------------------------------------------------------------#
# predict x_{t-1} (=pred_x)
# ------------------------- Variance Scheduler -----------------------#
# ------------------------- Variance Scheduler -----------------------#
# pred variance
posterior_log_variance
=
extract
(
posterior_log_variance_clipped
,
t
,
x_t
.
shape
)
posterior_log_variance
=
extract
(
posterior_log_variance_clipped
,
t
,
x_t
.
shape
)
# no noise when t == 0
b
,
*
_
,
device
=
*
x_t
.
shape
,
x_t
.
device
b
,
*
_
,
device
=
*
x_t
.
shape
,
x_t
.
device
nonzero_mask
=
(
1
-
(
t
==
0
).
float
()).
reshape
(
b
,
*
((
1
,)
*
(
len
(
x_t
.
shape
)
-
1
)))
nonzero_mask
=
(
1
-
(
t
==
0
).
float
()).
reshape
(
b
,
*
((
1
,)
*
(
len
(
x_t
.
shape
)
-
1
)))
posterior_variance
=
nonzero_mask
*
(
0.5
*
posterior_log_variance
).
exp
()
posterior_variance
=
nonzero_mask
*
(
0.5
*
posterior_log_variance
).
exp
()
# --------------------------------------------------------------------#
# --------------------------------------------------------------------#
x_t
=
posterior_mean
+
posterior_variance
*
noise
x_t_1
=
(
posterior_mean
+
posterior_variance
*
noise
).
to
(
torch
.
float32
)
x_t
=
x_t
.
to
(
torch
.
float32
)
# FOR PATRICK TO VERIFY: make sure manual loop is equal to function
# --------------------------------------------------------------------#
x_t_12
=
diffusion
.
p_sample
(
unet
,
x_t
,
t
,
noise
=
noise
)
assert
(
x_t_1
-
x_t_12
).
abs
().
sum
().
item
()
<
1e-3
# --------------------------------------------------------------------#
# make sure manual loop is equal to function
x_t
=
x_t_1
assert
(
x_t
-
x_t2
).
abs
().
sum
().
item
()
<
1e-3
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment