run.py 4.61 KB
Newer Older
1
2
3
4
#!/usr/bin/env python3
import numpy as np
import PIL
import torch
Patrick von Platen's avatar
Patrick von Platen committed
5
6
7
8
#from configs.ve import ffhq_ncsnpp_continuous as configs
#  from configs.ve import cifar10_ncsnpp_continuous as configs


Patrick von Platen's avatar
Patrick von Platen committed
9
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Keep CUDA matmuls out of TF32 mode and fix the RNG seed.
# NOTE(review): presumably both are here so that the reference sum/mean
# checked at the end of the script stay reproducible -- confirm.
torch.backends.cuda.matmul.allow_tf32 = False
torch.manual_seed(0)
13
14
15


class NewReverseDiffusionPredictor:
    """Predictor step that walks a sample down a geometric ladder of noise levels.

    The ladder has ``N`` entries spaced log-uniformly between ``sigma_min`` and
    ``sigma_max``. ``score_fn(x, labels)`` is called with the sample and the
    continuous noise level ``sigma_min * (sigma_max / sigma_min) ** t``.

    Args:
        score_fn: Callable ``(x, labels) -> tensor`` broadcastable against ``x``.
        probability_flow: When true, the score term is halved and no noise is
            injected (the returned diffusion coefficient is all zeros).
        sigma_min: Smallest noise level. The default ``0.0`` is only a
            placeholder: it makes ``np.log`` return ``-inf`` and the label
            computation divide by zero, so pass a positive value.
        sigma_max: Largest noise level (same caveat as ``sigma_min``).
        N: Number of discrete noise levels.
    """

    def __init__(self, score_fn, probability_flow=False, sigma_min=0.0, sigma_max=0.0, N=0):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.N = N
        # exp(linspace(log a, log b, N)) == N log-uniformly spaced values in [a, b].
        self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))

        self.probability_flow = probability_flow
        self.score_fn = score_fn

    def discretize(self, x, t):
        """Return the reverse-time drift and diffusion terms for one step.

        Args:
            x: Current sample. It is indexed as a 4-D tensor whose leading
                dimension matches ``t``.
            t: 1-D tensor of continuous times; ``t * (N - 1)`` is truncated to
                an index into the noise ladder.

        Returns:
            ``(rev_f, rev_G)``: ``rev_f`` has the shape of ``x`` and ``rev_G``
            has the shape of ``t``.
        """
        timestep = (t * (self.N - 1)).long()

        # Move the ladder to t's device once, *then* index it. Indexing the CPU
        # ladder with an index tensor that lives on another device is rejected
        # by current PyTorch releases.
        sigmas = self.discrete_sigmas.to(t.device)
        sigma = sigmas[timestep]
        # For timestep == 0 the index -1 wraps around to the last entry; that
        # value is discarded by the mask and replaced with zero.
        adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), sigmas[timestep - 1])

        f = torch.zeros_like(x)
        G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)

        labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        result = self.score_fn(x, labels)

        score_scale = 0.5 if self.probability_flow else 1.
        rev_f = f - G[:, None, None, None] ** 2 * result * score_scale
        rev_G = torch.zeros_like(G) if self.probability_flow else G
        return rev_f, rev_G

    def update_fn(self, x, t):
        """Apply one predictor step.

        Returns:
            ``(x, x_mean)``: the noised next sample and its noise-free mean.
        """
        f, G = self.discretize(x, t)
        z = torch.randn_like(x)
        x_mean = x - f
        x = x_mean + G[:, None, None, None] * z
        return x, x_mean


class NewLangevinCorrector:
    """Corrector that refines a sample with ``n_steps`` Langevin-style updates.

    Each update moves ``x`` along ``score_fn(x, labels)`` and adds Gaussian
    noise; the step size is derived from the ratio of the mean noise norm to
    the mean score norm, scaled by ``snr``.

    Args:
        score_fn: Callable ``(x, labels) -> tensor`` with the shape of ``x``.
        snr: Target signal-to-noise ratio used to size each step.
        n_steps: Number of updates per call to ``update_fn``.
        sigma_min: Smallest noise level. The default ``0.0`` is only a
            placeholder: the label computation divides by it.
        sigma_max: Largest noise level.
    """

    def __init__(self, score_fn, snr, n_steps, sigma_min=0.0, sigma_max=0.0):
        super().__init__()
        self.score_fn = score_fn
        self.snr = snr
        self.n_steps = n_steps

        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def update_fn(self, x, t):
        """Run the corrector updates.

        Args:
            x: Current sample. It is indexed as a 4-D tensor whose leading
                dimension matches ``t``.
            t: 1-D tensor of continuous times.

        Returns:
            ``(x, x_mean)``: the sample after the last update and its value
            before that update's noise was added. With ``n_steps <= 0`` the
            input is returned unchanged for both.
        """
        # Without any step there is nothing to correct; previously this path
        # failed because ``x_mean`` was never assigned.
        x_mean = x
        if self.n_steps <= 0:
            return x, x_mean

        score_fn = self.score_fn
        target_snr = self.snr

        # alpha is fixed to one here. The code this was ported from looked it
        # up per timestep (``sde.alphas``) for VPSDE/subVPSDE; that branch is
        # not ported.
        alpha = torch.ones_like(t)

        # The noise-level labels depend only on t, so compute them once.
        labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t

        for _ in range(self.n_steps):
            grad = score_fn(x, labels)
            noise = torch.randn_like(x)
            # Per-sample L2 norms, averaged over the batch.
            grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
            noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
            step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
            x_mean = x + step_size[:, None, None, None] * grad
            x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise

        return x, x_mean



def save_image(x, path="../images/hey.png"):
    """Write the first image of a batch to ``path``.

    Args:
        x: Tensor that is permuted with ``(0, 2, 3, 1)``, i.e. a 4-D batch.
            Values are scaled by 255 and clipped to ``[0, 255]``.
            NOTE(review): presumably channels-first with values in [0, 1] --
            confirm against the caller.
        path: Destination file. Defaults to the location that used to be
            hard-coded.
    """
    # ``import PIL`` on its own does not load the ``PIL.Image`` submodule, so
    # import it explicitly instead of relying on another package having done so.
    from PIL import Image

    image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
    image_pil = Image.fromarray(image_processed[0])
    image_pil.save(path)


Patrick von Platen's avatar
Patrick von Platen committed
89
#  ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
Patrick von Platen's avatar
Patrick von Platen committed
90
91
92
93
#ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note: normally the EMA weights etc. would also have to be restored from the checkpoint.
# The model loaded below comes from a checkpoint with the EMA weights already restored.

Patrick von Platen's avatar
Patrick von Platen committed
94
95
96
# Sampling hyper-parameters.
N = 2                # number of noise levels / predictor-corrector iterations
sigma_min = 0.01     # smallest noise level handed to predictor and corrector
sigma_max = 1348     # largest noise level; also scales the initial random sample
sampling_eps = 1e-5  # final timestep: timesteps run from 1 down to this value
batch_size = 1       # leading dimension of the sample shape
centered = False     # when true, the final sample is remapped with (x + 1) / 2
Patrick von Platen's avatar
Patrick von Platen committed
100
101

from diffusers import NCSNpp
Patrick von Platen's avatar
Patrick von Platen committed
102

Patrick von Platen's avatar
Patrick von Platen committed
103
104
# Load the pretrained score model and wrap it in DataParallel; the wrapped
# module (and therefore its config) is reachable through ``model.module``.
model = NCSNpp.from_pretrained("/home/patrick/ffhq_ncsnpp").to(device)
model = torch.nn.DataParallel(model)

# The sample shape comes from the model config: (batch, channels, height, width).
img_size = model.module.config.image_size
channels = model.module.config.num_channels
shape = (batch_size, channels, img_size, img_size)

probability_flow = False
snr = 0.15
n_steps = 1

new_corrector = NewLangevinCorrector(
    score_fn=model,
    snr=snr,
    n_steps=n_steps,
    sigma_min=sigma_min,
    sigma_max=sigma_max,
)
# ``probability_flow`` used to be defined but never forwarded; pass it along so
# that changing the setting above actually takes effect.
new_predictor = NewReverseDiffusionPredictor(
    score_fn=model,
    probability_flow=probability_flow,
    sigma_min=sigma_min,
    sigma_max=sigma_max,
    N=N,
)
116
117
118

with torch.no_grad():
    # Initial sample: Gaussian noise scaled to the largest noise level. It is
    # drawn on the CPU and then moved, so the values depend on the CPU RNG
    # seeded at the top of the script.
    x = torch.randn(*shape) * sigma_max
    x = x.to(device)
    timesteps = torch.linspace(1, sampling_eps, N, device=device)

    # Fallback so that ``x_mean`` is defined even if there are no timesteps.
    x_mean = x

    for t in timesteps:
        # Broadcast the scalar timestep to one entry per sample in the batch.
        vec_t = torch.ones(shape[0], device=t.device) * t
        x, x_mean = new_corrector.update_fn(x, vec_t)
        x, x_mean = new_predictor.update_fn(x, vec_t)

    # Keep the noise-free mean of the last step as the final sample.
    x = x_mean
    if centered:
        x = (x + 1.) / 2.
132

Patrick von Platen's avatar
up  
Patrick von Platen committed
133

Patrick von Platen's avatar
Patrick von Platen committed
134
# save_image(x)
Patrick von Platen's avatar
Patrick von Platen committed
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150

# Reference |x| sum and |x| mean of the final sample for several setups.
# Keys are (label, count), taken from the original notes "for 5 cifar10",
# "for 1000 cifar10" and "for 2 for 1024".
# NOTE(review): the count is presumably the number of sampling steps N and
# "1024" presumably the 1024-pixel model -- confirm.
_REFERENCE_STATS = {
    ("cifar10", 5): (106071.9922, 34.52864456176758),
    ("cifar10", 1000): (461.9700, 0.1504),
    ("1024", 2): (3382810112.0, 1075.366455078125),
}

# Only the last pair ever took effect; the earlier assignments were
# overwritten. Select it explicitly.
x_sum, x_mean = _REFERENCE_STATS[("1024", 2)]

def check_x_sum_x_mean(x, x_sum, x_mean):
    """Assert that the absolute sum and absolute mean of ``x`` match references.

    Args:
        x: Tensor to check.
        x_sum: Expected value of ``x.abs().sum()`` (tolerance ``1e-2``).
        x_mean: Expected value of ``x.abs().mean()`` (tolerance ``1e-4``).

    Raises:
        AssertionError: If either statistic is outside its tolerance.
    """
    magnitudes = x.abs()
    sum_error = (magnitudes.sum() - x_sum).abs().cpu().item()
    assert sum_error < 1e-2, f"sum wrong {magnitudes.sum()}"
    mean_error = (magnitudes.mean() - x_mean).abs().cpu().item()
    assert mean_error < 1e-4, f"mean wrong {magnitudes.mean()}"
151
152


Patrick von Platen's avatar
Patrick von Platen committed
153
# Compare the final sample against the reference statistics selected above.
check_x_sum_x_mean(x=x, x_sum=x_sum, x_mean=x_mean)