Commit 43e728d3 authored by anton-l's avatar anton-l
Browse files

Merge remote-tracking branch 'origin/main'

parents 383dc795 9fdbc14e
......@@ -31,41 +31,43 @@ class DDIM(DiffusionPipeline):
super().__init__()
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, inference_time_steps=50):
def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):
# eta is η in paper
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
num_timesteps = self.noise_scheduler.num_timesteps
seq = range(0, num_timesteps, num_timesteps // inference_time_steps)
b = self.noise_scheduler.betas.to(torch_device)
num_trained_timesteps = self.noise_scheduler.num_timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
self.unet.to(torch_device)
x = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
with torch.no_grad():
n = batch_size
seq_next = [-1] + list(seq[:-1])
x0_preds = []
xs = [x]
for i, j in zip(reversed(seq), reversed(seq_next)):
print(i)
t = (torch.ones(n) * i).to(x.device)
next_t = (torch.ones(n) * j).to(x.device)
at = compute_alpha(b, t.long())
at_next = compute_alpha(b, next_t.long())
xt = xs[-1].to('cuda')
et = self.unet(xt, t)
x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
x0_preds.append(x0_t.to('cpu'))
# eta
c1 = (
eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
)
c2 = ((1 - at_next) - c1 ** 2).sqrt()
xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
xs.append(xt_next.to('cpu'))
return xt_next
image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
# get actual t and t-1
train_step = inference_step_times[t]
prev_train_step = inference_step_times[t - 1] if t > 0 else - 1
# compute alphas
alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step)
alpha_prod_t_rsqrt = 1 / alpha_prod_t.sqrt()
alpha_prod_t_prev_rsqrt = 1 / alpha_prod_t_prev.sqrt()
beta_prod_t_sqrt = (1 - alpha_prod_t).sqrt()
beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt()
# compute relevant coefficients
coeff_1 = (alpha_prod_t_prev - alpha_prod_t).sqrt() * alpha_prod_t_prev_rsqrt * beta_prod_t_prev_sqrt / beta_prod_t_sqrt * eta
coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1 ** 2).sqrt()
with torch.no_grad():
noise_residual = self.unet(image, train_step)
print(train_step)
pred_mean = (image - noise_residual * beta_prod_t_sqrt) * alpha_prod_t_rsqrt
xt_next = alpha_prod_t_prev.sqrt() * pred_mean + coeff_1 * torch.randn_like(image) + coeff_2 * noise_residual
# xt_next = 1 / alpha_prod_t_rsqrt * pred_mean + coeff_1 * torch.randn_like(image) + coeff_2 * noise_residual
# eta
image = xt_next
return image
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment