#!/usr/bin/env python3
import numpy as np
import PIL
import torch
import ml_collections
#from configs.ve import ffhq_ncsnpp_continuous as configs
#  from configs.ve import cifar10_ncsnpp_continuous as configs


# Configuration for ffhq_ncsnpp_continuous.
def get_config():
    """Build the ConfigDict for continuous-time VE-SDE NCSN++ on FFHQ 1024x1024.

    Returns a config with sections: training, sampling, eval, data, model,
    optim, plus top-level seed and device.
    """
    config = ml_collections.ConfigDict()

    # training
    config.training = ml_collections.ConfigDict()
    config.training.batch_size = 8
    config.training.n_iters = 2400001
    config.training.snapshot_freq = 50000
    config.training.log_freq = 50
    config.training.eval_freq = 100
    config.training.snapshot_freq_for_preemption = 5000
    config.training.snapshot_sampling = True
    config.training.sde = 'vesde'
    config.training.continuous = True
    config.training.likelihood_weighting = False
    config.training.reduce_mean = True

    # sampling: predictor-corrector with reverse diffusion + Langevin
    config.sampling = ml_collections.ConfigDict()
    config.sampling.method = 'pc'
    config.sampling.predictor = 'reverse_diffusion'
    config.sampling.corrector = 'langevin'
    config.sampling.probability_flow = False
    config.sampling.snr = 0.15
    config.sampling.n_steps_each = 1
    config.sampling.noise_removal = True

    # eval
    config.eval = ml_collections.ConfigDict()
    config.eval.batch_size = 1024
    config.eval.num_samples = 50000
    config.eval.begin_ckpt = 1
    config.eval.end_ckpt = 96

    # data
    config.data = ml_collections.ConfigDict()
    config.data.dataset = 'FFHQ'
    config.data.image_size = 1024
    config.data.centered = False
    config.data.random_flip = True
    config.data.uniform_dequantization = False
    config.data.num_channels = 3
    # Plug in your own path to the tfrecords file.
    config.data.tfrecords_path = '/raid/song/ffhq-dataset/ffhq/ffhq-r10.tfrecords'

    # model
    config.model = ml_collections.ConfigDict()
    config.model.name = 'ncsnpp'
    config.model.scale_by_sigma = True
    config.model.sigma_max = 1348
    config.model.num_scales = 2000
    config.model.ema_rate = 0.9999
    config.model.sigma_min = 0.01
    config.model.normalization = 'GroupNorm'
    config.model.nonlinearity = 'swish'
    config.model.nf = 16
    config.model.ch_mult = (1, 2, 4, 8, 16, 32, 32, 32)
    config.model.num_res_blocks = 1
    config.model.attn_resolutions = (16,)
    config.model.dropout = 0.
    config.model.resamp_with_conv = True
    config.model.conditional = True
    config.model.fir = True
    config.model.fir_kernel = [1, 3, 3, 1]
    config.model.skip_rescale = True
    config.model.resblock_type = 'biggan'
    config.model.progressive = 'output_skip'
    config.model.progressive_input = 'input_skip'
    config.model.progressive_combine = 'sum'
    config.model.attention_type = 'ddpm'
    config.model.init_scale = 0.
    config.model.fourier_scale = 16
    config.model.conv_size = 3
    config.model.embedding_type = 'fourier'

    # optim
    config.optim = ml_collections.ConfigDict()
    config.optim.weight_decay = 0
    config.optim.optimizer = 'Adam'
    config.optim.lr = 2e-4
    config.optim.beta1 = 0.9
    config.optim.amsgrad = False
    config.optim.eps = 1e-8
    config.optim.warmup = 5000
    config.optim.grad_clip = 1.

    config.seed = 42
    config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    return config
# Disable TF32 matmuls so results are bit-reproducible against fp32 reference
# outputs, and fix the global RNG seed for deterministic sampling.
torch.backends.cuda.matmul.allow_tf32 = False
torch.manual_seed(3)


class NewReverseDiffusionPredictor:
  """Reverse-diffusion predictor step for a VE SDE with a geometric sigma schedule.

  Standalone version that takes (sigma_min, sigma_max, N) directly instead of
  an SDE object; `score_fn(x, labels)` returns the score estimate.
  """

  def __init__(self, score_fn, probability_flow=False, sigma_min=0.0, sigma_max=0.0, N=0):
    super().__init__()
    self.sigma_min = sigma_min
    self.sigma_max = sigma_max
    self.N = N
    # Geometrically spaced sigmas from sigma_min to sigma_max (N values).
    self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))

    self.probability_flow = probability_flow
    self.score_fn = score_fn

  def discretize(self, x, t):
    """Return drift `rev_f` and diffusion `rev_G` of the discretized reverse SDE at time t."""
    timestep = (t * (self.N - 1)).long()
    # Move the schedule to t's device once and index there. The original code
    # indexed the CPU-resident schedule with `timestep` before `.to(t.device)`
    # in the adjacent_sigma line, which fails when t lives on CUDA.
    sigmas = self.discrete_sigmas.to(t.device)
    sigma = sigmas[timestep]
    # At timestep 0 there is no previous sigma; treat it as 0.
    adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), sigmas[timestep - 1])
    f = torch.zeros_like(x)
    G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)

    # Continuous noise level fed to the score model for this t.
    labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
    result = self.score_fn(x, labels)

    # Probability-flow ODE uses half the score term and no diffusion.
    rev_f = f - G[:, None, None, None] ** 2 * result * (0.5 if self.probability_flow else 1.)
    rev_G = torch.zeros_like(G) if self.probability_flow else G
    return rev_f, rev_G

  def update_fn(self, x, t):
    """One predictor update; returns (new sample, noise-free mean)."""
    f, G = self.discretize(x, t)
    z = torch.randn_like(x)
    x_mean = x - f
    x = x_mean + G[:, None, None, None] * z
    return x, x_mean


class NewLangevinCorrector:
Patrick von Platen's avatar
Patrick von Platen committed
141
  def __init__(self, score_fn, snr, n_steps, sigma_min=0.0, sigma_max=0.0):
142
143
144
145
146
    super().__init__()
    self.score_fn = score_fn
    self.snr = snr
    self.n_steps = n_steps

Patrick von Platen's avatar
Patrick von Platen committed
147
148
149
    self.sigma_min = sigma_min
    self.sigma_max = sigma_max

150
151
152
153
  def update_fn(self, x, t):
    score_fn = self.score_fn
    n_steps = self.n_steps
    target_snr = self.snr
Patrick von Platen's avatar
Patrick von Platen committed
154
155
156
157
158
#    if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE):
#      timestep = (t * (sde.N - 1) / sde.T).long()
#      alpha = sde.alphas.to(t.device)[timestep]
#    else:
    alpha = torch.ones_like(t)
159
160

    for i in range(n_steps):
Patrick von Platen's avatar
Patrick von Platen committed
161
      labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
      grad = score_fn(x, labels)
      noise = torch.randn_like(x)
      grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
      noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
      step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
      x_mean = x + step_size[:, None, None, None] * grad
      x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise

    return x, x_mean



def save_image(x, path="../images/hey.png"):
    """Save the first image of batch `x` (NCHW, float values in [0, 1]) as a PNG.

    `path` defaults to the original hard-coded location for backward
    compatibility; pass a different path to save elsewhere.
    """
    # `import PIL` at the top of the file does not guarantee the Image
    # submodule is loaded as an attribute; import it explicitly here.
    from PIL import Image
    image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
    image_pil = Image.fromarray(image_processed[0])
    image_pil.save(path)


# Checkpoints used during development:
#   ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
#   ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note: usually we would need to restore EMA etc.; the checkpoint loaded
# below already has EMA weights applied.

config = get_config()

sigma_min, sigma_max = config.model.sigma_min, config.model.sigma_max
N = config.model.num_scales

sampling_eps = 1e-5  # smallest t sampled; avoids the singularity at t = 0

batch_size = 1  #@param {"type":"integer"}
config.training.batch_size = batch_size
config.eval.batch_size = batch_size

from diffusers import NCSNpp

# Build the score model; DataParallel wrapping matches the checkpoint's
# "module."-prefixed state-dict keys.
model = NCSNpp(config).to(config.device)
model = torch.nn.DataParallel(model)

loaded_state = torch.load("../score_sde_pytorch/ffhq_1024_ncsnpp_continuous_ema.pt")
# Drop the stored sigmas buffer; the schedule is rebuilt from the config.
del loaded_state["module.sigmas"]
model.load_state_dict(loaded_state, strict=False)
def get_data_inverse_scaler(config):
  """Return a function mapping model-space images back to [0, 1]."""
  if not config.data.centered:
    # Data was already in [0, 1]; identity.
    return lambda x: x
  # Rescale [-1, 1] back to [0, 1].
  return lambda x: (x + 1.) / 2.

inverse_scaler = get_data_inverse_scaler(config)

# Sampling hyper-parameters.
img_size = config.data.image_size
channels = config.data.num_channels
shape = (batch_size, channels, img_size, img_size)
probability_flow = False
snr = 0.15  #@param {"type": "number"}
n_steps = 1  #@param {"type": "integer"}

device = config.device

# Predictor-corrector pair; the network itself serves as the score function.
new_corrector = NewLangevinCorrector(score_fn=model, snr=snr, n_steps=n_steps, sigma_min=sigma_min, sigma_max=sigma_max)
new_predictor = NewReverseDiffusionPredictor(score_fn=model, sigma_min=sigma_min, sigma_max=sigma_max, N=N)
with torch.no_grad():
    # Initial sample from the VE SDE prior: N(0, sigma_max^2 I).
    x = torch.randn(*shape) * sigma_max
    x = x.to(device)
    timesteps = torch.linspace(1, sampling_eps, N, device=device)

    for i in range(N):
        t = timesteps[i]
        vec_t = torch.ones(shape[0], device=t.device) * t

        # Corrector (Langevin) then predictor (reverse diffusion) each step.
        x, x_mean = new_corrector.update_fn(x, vec_t)
        x, x_mean = new_predictor.update_fn(x, vec_t)

    # Final denoised output: the noise-free mean, mapped back to [0, 1].
    x = inverse_scaler(x_mean)


save_image(x)
# Reference statistics of |x| recorded from known-good runs, used to
# regression-test the sampler. Earlier pairs (kept for reference) were
# immediately shadowed in the original; only the active pair is assigned.
#   5-step CIFAR-10:    x_sum = 106071.9922,  x_mean = 34.52864456176758
#   1000-step CIFAR-10: x_sum = 461.9700,     x_mean = 0.1504
# 2 steps at 1024x1024 (the configuration above):
x_sum = 3382810112.0
x_mean = 1075.366455078125


def check_x_sum_x_mean(x, x_sum, x_mean):
    """Assert |x|'s sum and mean match the recorded values within tight tolerances."""
    assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
    assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"


# check_x_sum_x_mean(x, x_sum, x_mean)