Commit 7a1323b6 authored by Patrick von Platen

add first version of ddim

parent 86064df7
@@ -32,11 +32,16 @@ class DDIM(DiffusionPipeline):
         self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
 
     def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, inference_time_steps=50):
-        seq = range(0, self.num_timesteps, self.num_timesteps // inference_time_steps)
-        b = self.noise_scheduler.betas
+        # eta is η in paper
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 
+        num_timesteps = self.noise_scheduler.num_timesteps
+        seq = range(0, num_timesteps, num_timesteps // inference_time_steps)
+        b = self.noise_scheduler.betas.to(torch_device)
+
         self.unet.to(torch_device)
         x = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
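For intuition on the added lines above: seq strides the training timesteps down to inference_time_steps evenly spaced steps, so the sampling loop visits only that subsequence. A minimal sketch of the arithmetic, assuming the scheduler's usual default of 1000 training timesteps (a value this hunk does not show):

    num_timesteps = 1000        # assumed scheduler default, not shown in this commit
    inference_time_steps = 50   # the default from the __call__ signature above
    seq = range(0, num_timesteps, num_timesteps // inference_time_steps)
    print(list(seq)[:4], "...", list(seq)[-1])  # [0, 20, 40, 60] ... 980

Visiting 50 of 1000 timesteps is the main speed-up DDIM sampling offers over ancestral DDPM sampling.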
@@ -63,5 +68,4 @@ class DDIM(DiffusionPipeline):
                 xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
                 xs.append(xt_next.to('cpu'))
 
-        import ipdb; ipdb.set_trace()
-        return xs, x0_preds
+        return xt_next
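The context lines above use c1 and c2 without showing their definitions, and at/at_next for the cumulative alpha products of the current and next timestep. A sketch of how the two coefficients are typically derived for the η-parameterized DDIM update, following the DDIM paper (arXiv:2010.02502) rather than code visible in this commit:

    import torch

    def ddim_coefficients(at: torch.Tensor, at_next: torch.Tensor, eta: float):
        # c1 is sigma_t from the DDIM paper: the standard deviation of the fresh
        # noise injected at this step, scaled linearly by eta.
        c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
        # c2 weights the predicted noise et so the marginal variance of xt_next
        # stays consistent with the forward process.
        c2 = ((1 - at_next) - c1 ** 2).sqrt()
        return c1, c2

With eta=0.0 the torch.randn_like term vanishes and the update xt_next = at_next.sqrt() * x0_t + c2 * et is fully deterministic; eta=1.0 recovers DDPM-like stochastic sampling. That is why the new test below calls the pipeline with eta=0.0 and still expects a reproducible image slice.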
@@ -25,6 +25,7 @@ import torch
 from diffusers import GaussianDDPMScheduler, UNetModel
 from diffusers.pipeline_utils import DiffusionPipeline
 from models.vision.ddpm.modeling_ddpm import DDPM
+from models.vision.ddim.modeling_ddim import DDIM
 
 global_rng = random.Random()
@@ -205,6 +206,7 @@ class SamplerTesterMixin(unittest.TestCase):
 
 class PipelineTesterMixin(unittest.TestCase):
     def test_from_pretrained_save_pretrained(self):
         # 1. Load models
         model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
@@ -241,3 +243,31 @@ class PipelineTesterMixin(unittest.TestCase):
         new_image = ddpm_from_hub(generator=generator)
 
         assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
+
+    @slow
+    def test_ddpm_cifar10(self):
+        generator = torch.manual_seed(0)
+        model_id = "fusing/ddpm-cifar10"
+
+        ddpm = DDPM.from_pretrained(model_id)
+        image = ddpm(generator=generator)
+
+        image_slice = image[0, -1, -3:, -3:].cpu()
+
+        assert image.shape == (1, 3, 32, 32)
+        expected_slice = torch.tensor([0.2250, 0.3375, 0.2360, 0.0930, 0.3440, 0.3156, 0.1937, 0.3585, 0.1761])
+        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+
+    @slow
+    def test_ddim_cifar10(self):
+        generator = torch.manual_seed(0)
+        model_id = "fusing/ddpm-cifar10"
+
+        ddim = DDIM.from_pretrained(model_id)
+        image = ddim(generator=generator, eta=0.0)
+
+        image_slice = image[0, -1, -3:, -3:].cpu()
+
+        assert image.shape == (1, 3, 32, 32)
+        expected_slice = torch.tensor([-0.7688, -0.7690, -0.7597, -0.7660, -0.7713, -0.7531, -0.7009, -0.7098, -0.7350])
+        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
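Read together with the first hunk, the new test also serves as a usage example. A minimal sketch of sampling with the pipeline, using only calls that appear in this commit:

    import torch
    from models.vision.ddim.modeling_ddim import DDIM

    generator = torch.manual_seed(0)
    ddim = DDIM.from_pretrained("fusing/ddpm-cifar10")  # DDIM reuses the DDPM CIFAR-10 weights
    image = ddim(generator=generator, eta=0.0)          # deterministic; returns a (1, 3, 32, 32) tensor

Passing a larger inference_time_steps trades speed for fidelity; the default of 50 is set in the __call__ signature in the first hunk.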