"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "e20362e0d5fef9e87e34c003ead3feafa9bba24c"
Commit 33e5a831 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

finish DDIM

parent 9fdbc14e
...@@ -19,12 +19,6 @@ import tqdm ...@@ -19,12 +19,6 @@ import tqdm
import torch import torch
def compute_alpha(beta, t):
beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
return a
class DDIM(DiffusionPipeline): class DDIM(DiffusionPipeline):
def __init__(self, unet, noise_scheduler): def __init__(self, unet, noise_scheduler):
...@@ -32,7 +26,7 @@ class DDIM(DiffusionPipeline): ...@@ -32,7 +26,7 @@ class DDIM(DiffusionPipeline):
self.register_modules(unet=unet, noise_scheduler=noise_scheduler) self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50): def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):
# eta is η in paper # eta corresponds to η in paper and should be between [0, 1]
if torch_device is None: if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu" torch_device = "cuda" if torch.cuda.is_available() else "cpu"
...@@ -59,15 +53,19 @@ class DDIM(DiffusionPipeline): ...@@ -59,15 +53,19 @@ class DDIM(DiffusionPipeline):
coeff_1 = (alpha_prod_t_prev - alpha_prod_t).sqrt() * alpha_prod_t_prev_rsqrt * beta_prod_t_prev_sqrt / beta_prod_t_sqrt * eta coeff_1 = (alpha_prod_t_prev - alpha_prod_t).sqrt() * alpha_prod_t_prev_rsqrt * beta_prod_t_prev_sqrt / beta_prod_t_sqrt * eta
coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1 ** 2).sqrt() coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1 ** 2).sqrt()
# model forward
with torch.no_grad(): with torch.no_grad():
noise_residual = self.unet(image, train_step) noise_residual = self.unet(image, train_step)
print(train_step) # predict mean of prev image
pred_mean = alpha_prod_t_rsqrt * (image - beta_prod_t_sqrt * noise_residual)
pred_mean = (1 / alpha_prod_t_prev_rsqrt) * pred_mean + coeff_2 * noise_residual
pred_mean = (image - noise_residual * beta_prod_t_sqrt) * alpha_prod_t_rsqrt # if eta > 0.0 add noise. Note eta = 1.0 essentially corresponds to DDPM
xt_next = alpha_prod_t_prev.sqrt() * pred_mean + coeff_1 * torch.randn_like(image) + coeff_2 * noise_residual if eta > 0.0:
# xt_next = 1 / alpha_prod_t_rsqrt * pred_mean + coeff_1 * torch.randn_like(image) + coeff_2 * noise_residual noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
# eta image = pred_mean + coeff_1 * noise
image = xt_next else:
image = pred_mean
return image return image
...@@ -6,7 +6,12 @@ import numpy as np ...@@ -6,7 +6,12 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import repeat, rearrange
try:
from einops import repeat, rearrange
except:
print("Einops is not installed")
pass
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment