Commit e8977e95 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

save intermediate

parent ee71a3b6
...@@ -31,33 +31,40 @@ class DDIM(DiffusionPipeline): ...@@ -31,33 +31,40 @@ class DDIM(DiffusionPipeline):
super().__init__() super().__init__()
self.register_modules(unet=unet, noise_scheduler=noise_scheduler) self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, inference_time_steps=50): def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):
# eta is η in paper # eta is η in paper
if torch_device is None: if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu" torch_device = "cuda" if torch.cuda.is_available() else "cpu"
num_timesteps = self.noise_scheduler.num_timesteps num_trained_timesteps = self.noise_scheduler.num_timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
seq = range(0, num_timesteps, num_timesteps // inference_time_steps)
b = self.noise_scheduler.betas.to(torch_device)
self.unet.to(torch_device) self.unet.to(torch_device)
x = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator) x = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
with torch.no_grad(): b = self.noise_scheduler.betas.to(torch_device)
seq = inference_step_times
seq_next = [-1] + list(seq[:-1])
# for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
# train_step = inference_step_times[t]
for i, j in zip(reversed(seq), reversed(seq_next)):
n = batch_size n = batch_size
seq_next = [-1] + list(seq[:-1])
x0_preds = [] x0_preds = []
xs = [x] xs = [x]
for i, j in zip(reversed(seq), reversed(seq_next)):
# i = train_step
# j = inference_step_times[t-1] if t > 0 else -1
if True:
print(i) print(i)
t = (torch.ones(n) * i).to(x.device) t = (torch.ones(n) * i).to(x.device)
next_t = (torch.ones(n) * j).to(x.device) next_t = (torch.ones(n) * j).to(x.device)
at = compute_alpha(b, t.long()) at = compute_alpha(b, t.long())
at_next = compute_alpha(b, next_t.long()) at_next = compute_alpha(b, next_t.long())
xt = xs[-1].to('cuda') xt = xs[-1].to('cuda')
et = self.unet(xt, t) with torch.no_grad():
et = self.unet(xt, t)
x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt() x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
x0_preds.append(x0_t.to('cpu')) x0_preds.append(x0_t.to('cpu'))
# eta # eta
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment