Commit 3a5c65d5 authored by Patrick von Platen
Browse files

finish

parent 2032ad93
...@@ -60,36 +60,40 @@ posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20)) ...@@ -60,36 +60,40 @@ posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))
sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod) sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1) sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)
torch.manual_seed(0)
# Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf
# 1: x_t ~ N(0,1)
x_t = dummy_noise x_t = dummy_noise
# 2: for t = T, ..., 1 do
for i in reversed(range(TIME_STEPS)): for i in reversed(range(TIME_STEPS)):
# t for x_t
t = torch.tensor([i]) t = torch.tensor([i])
torch.manual_seed(0) # 3: z ~ N(0, 1)
noise = noise_like(x_t.shape, "cpu") noise = noise_like(x_t.shape, "cpu")
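x_t2 = diffusion.p_sample(unet, x_t, t, noise=noise) # 4: x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right) + \sigma_t z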
# ------------------------- MODEL ------------------------------------# # ------------------------- MODEL ------------------------------------#
# predict epsilon pred_noise = unet(x_t, t) # pred epsilon_theta
pred_noise = unet(x_t, t)
pred_x = extract(sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(sqrt_recipm1_alphas_cumprod, t, x_t.shape) * pred_noise pred_x = extract(sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(sqrt_recipm1_alphas_cumprod, t, x_t.shape) * pred_noise
pred_x.clamp_(-1.0, 1.0) pred_x.clamp_(-1.0, 1.0)
# pred mean
posterior_mean = extract(posterior_mean_coef1, t, x_t.shape) * pred_x + extract(posterior_mean_coef2, t, x_t.shape) * x_t posterior_mean = extract(posterior_mean_coef1, t, x_t.shape) * pred_x + extract(posterior_mean_coef2, t, x_t.shape) * x_t
# --------------------------------------------------------------------# # --------------------------------------------------------------------#
# predict x_{t-1} (=pred_x)
# ------------------------- Variance Scheduler -----------------------# # ------------------------- Variance Scheduler -----------------------#
# pred variance
posterior_log_variance = extract(posterior_log_variance_clipped, t, x_t.shape) posterior_log_variance = extract(posterior_log_variance_clipped, t, x_t.shape)
# no noise when t == 0
b, *_, device = *x_t.shape, x_t.device b, *_, device = *x_t.shape, x_t.device
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_t.shape) - 1))) nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_t.shape) - 1)))
posterior_variance = nonzero_mask * (0.5 * posterior_log_variance).exp() posterior_variance = nonzero_mask * (0.5 * posterior_log_variance).exp()
# --------------------------------------------------------------------# # --------------------------------------------------------------------#
x_t = posterior_mean + posterior_variance * noise x_t_1 = (posterior_mean + posterior_variance * noise).to(torch.float32)
x_t = x_t.to(torch.float32)
# FOR PATRICK TO VERIFY: make sure manual loop is equal to function
# --------------------------------------------------------------------#
x_t_12 = diffusion.p_sample(unet, x_t, t, noise=noise)
assert (x_t_1 - x_t_12).abs().sum().item() < 1e-3
# --------------------------------------------------------------------#
# make sure manual loop is equal to function x_t = x_t_1
assert (x_t - x_t2).abs().sum().item() < 1e-3
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment