#!/usr/bin/env python3
from diffusers import UNetModel, GaussianDiffusion
import torch
import torch.nn.functional as F

# Pull the tiny dummy DDPM model and its matching diffusion config from the Hub.
unet = UNetModel.from_pretrained("fusing/ddpm_dummy")
diffusion = GaussianDiffusion.from_config("fusing/ddpm_dummy")

# 2. Do one denoising step with model
# Fixed all-ones "noise" image so the manual loop below is reproducible.
batch_size, num_channels, height, width = 1, 3, 32, 32
dummy_noise = torch.ones(batch_size, num_channels, height, width)


# Number of diffusion timesteps T used by the schedule and the sampling loop.
TIME_STEPS = 10


# Helper
def extract(a, t, x_shape):
    """Gather the per-timestep entries of ``a`` at indices ``t`` and reshape
    them to ``(batch, 1, 1, ...)`` so they broadcast against ``x_shape``."""
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    trailing = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing)


def noise_like(shape, device, repeat=False):
    """Draw standard-normal noise of ``shape`` on ``device``.

    With ``repeat=True`` a single batch element is sampled and tiled along
    the batch dimension, so every element receives identical noise.
    """
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        return single.repeat(shape[0], *((1,) * (len(shape) - 1)))
    return torch.randn(shape, device=device)


# Schedule
def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    # Evaluate the squared-cosine cumulative-alpha curve on timesteps + 1 points.
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64)
    curve = torch.cos(((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    # Normalize so the curve starts at 1, then take adjacent ratios to get betas.
    curve = curve / curve[0]
    betas = 1 - (curve[1:] / curve[:-1])
    return torch.clip(betas, 0, 0.999)


# Precompute every per-timestep schedule constant needed by the reverse process.
betas = cosine_beta_schedule(TIME_STEPS)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
# alpha-bar_{t-1}, with alpha-bar at t=-1 defined as 1 (left-pad with 1.0).
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

# Coefficients of the posterior mean q(x_{t-1} | x_t, x_0): mean =
# coef1 * x_0 + coef2 * x_t (cf. DDPM, https://arxiv.org/pdf/2006.11239.pdf eq. 7).
posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)

# Posterior variance of q(x_{t-1} | x_t, x_0); the log is clamped because the
# variance is 0 at t = 0 and log(0) would be -inf.
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))


# Constants for recovering x_0 from x_t and the predicted noise:
# x_0 = sqrt(1/abar_t) * x_t - sqrt(1/abar_t - 1) * eps.
sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)

torch.manual_seed(0)


# Compare the following to Algorithm 2 Sampling of paper: https://arxiv.org/pdf/2006.11239.pdf
# 1: x_t ~ N(0,1)  (a fixed all-ones tensor here, for reproducibility)
x_t = dummy_noise

# 2: for t = T, ...., 1 do
for i in reversed(range(TIME_STEPS)):
    t = torch.tensor([i])

    # 3: z ~ N(0, 1)
    noise = noise_like(x_t.shape, "cpu")

    # 4: x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t/sqrt(1 - abar_t) * eps_theta(x_t, t)) + sigma_t * z
    # ------------------------- MODEL ------------------------------------#
    pred_noise = unet(x_t, t)  # pred epsilon_theta

    # Reconstruct x_0 from x_t and the predicted noise, then clamp to the
    # valid image range [-1, 1].
    pred_x = extract(sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract(sqrt_recipm1_alphas_cumprod, t, x_t.shape) * pred_noise
    pred_x.clamp_(-1.0, 1.0)

    # pred mean: posterior mean of q(x_{t-1} | x_t, x_0)
    posterior_mean = extract(posterior_mean_coef1, t, x_t.shape) * pred_x + extract(posterior_mean_coef2, t, x_t.shape) * x_t
    # --------------------------------------------------------------------#

    # ------------------------- Variance Scheduler -----------------------#
    # pred variance
    posterior_log_variance = extract(posterior_log_variance_clipped, t, x_t.shape)
    b = x_t.shape[0]
    # No noise is added at the final step (t == 0), per Algorithm 2.
    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_t.shape) - 1)))
    # NOTE: exp(0.5 * log var) is the standard deviation sigma_t, not the
    # variance. It was previously named `posterior_variance`, which both
    # misstated what it holds and shadowed/clobbered the module-level
    # schedule tensor of the same name after the first iteration.
    posterior_std = nonzero_mask * (0.5 * posterior_log_variance).exp()
    # --------------------------------------------------------------------#

    x_t_1 = (posterior_mean + posterior_std * noise).to(torch.float32)

    # FOR PATRICK TO VERIFY: make sure manual loop is equal to function
    # --------------------------------------------------------------------#
    x_t_12 = diffusion.p_sample(unet, x_t, t, noise=noise)
    assert (x_t_1 - x_t_12).abs().sum().item() < 1e-3
    # --------------------------------------------------------------------#

    x_t = x_t_1