Commit 6ab2dd18 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

up

parent fe313730
#!/usr/bin/env python3
"""Sample an image from a pretrained DDPM and exercise save/load round-tripping.

Usage: script.py <model_id> <output_folder> <save_flag: 0 or 1>
"""
import sys
import tempfile

import numpy as np
import PIL.Image

from diffusers import GaussianDDPMScheduler, UNetModel
from modeling_ddpm import DDPM

# CLI arguments: hub model id, folder to optionally save the pipeline to,
# and a 0/1 flag controlling whether the pipeline is written to `folder`.
model_id = sys.argv[1]
folder = sys.argv[2]
save = bool(int(sys.argv[3]))

# Load the pretrained components and compose the diffusion pipeline.
# (The original also loaded "fusing/ddpm_dummy" first and immediately
# overwrote it — that redundant load is removed here.)
unet = UNetModel.from_pretrained(model_id)
sampler = GaussianDDPMScheduler.from_config(model_id)
ddpm = DDPM(unet, sampler)

# Optionally persist the composed pipeline to the user-supplied folder.
if save:
    ddpm.save_pretrained(folder)

# Generate / sample an image; expected to be an NCHW tensor in [-1, 1]
# given the +1 / *127.5 rescaling below.
image = ddpm()
print(image)

# Convert the sampled tensor to an 8-bit RGB PIL image and write it out.
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) * 127.5
image_processed = image_processed.numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("test.png")

# Save and load with 0 extra code (handled by general `DiffusionPipeline` class).
with tempfile.TemporaryDirectory() as tmpdirname:
    ddpm.save_pretrained(tmpdirname)
    print("Model saved")
    ddpm_new = DDPM.from_pretrained(tmpdirname)
    print("Model loaded")
    print(ddpm_new)
......@@ -15,12 +15,43 @@
from diffusers import DiffusionPipeline
import tqdm
import torch
class DDPM(DiffusionPipeline):
    """Denoising Diffusion Probabilistic Model sampling pipeline.

    Composes a ``unet`` noise-prediction model with a ``noise_scheduler``
    and runs the reverse diffusion process to sample an image.
    """

    def __init__(self, unet, noise_scheduler):
        super().__init__()
        # Register sub-modules so the generic DiffusionPipeline save/load
        # machinery can serialize and restore them by name.
        self.register_modules(unet=unet, noise_scheduler=noise_scheduler)

    def __call__(self, generator=None, torch_device=None):
        """Run the full reverse diffusion chain and return one sampled image.

        Args:
            generator: optional torch.Generator for reproducible sampling.
            torch_device: requested device. NOTE(review): currently ignored —
                the device is always auto-detected below; confirm intent.

        Returns:
            A (1, C, H, W) image tensor on ``torch_device``.
        """
        torch_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.unet.to(torch_device)

        # 1. Sample gaussian noise x_T as the starting image.
        image = self.noise_scheduler.sample_noise(
            (1, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
            device=torch_device,
            generator=generator,
        )

        # 2. Iterate the reverse process t = T-1 ... 0.
        for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)):
            # i) define coefficients for time step t
            clip_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t))
            clip_noise_coeff = torch.sqrt(1 / self.noise_scheduler.get_alpha_prod(t) - 1)
            image_coeff = (
                (1 - self.noise_scheduler.get_alpha_prod(t - 1))
                * torch.sqrt(self.noise_scheduler.get_alpha(t))
                / (1 - self.noise_scheduler.get_alpha_prod(t))
            )
            clip_coeff = (
                torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1))
                * self.noise_scheduler.get_beta(t)
                / (1 - self.noise_scheduler.get_alpha_prod(t))
            )

            # ii) predict noise residual (inference only — no gradients)
            with torch.no_grad():
                noise_residual = self.unet(image, t)

            # iii) compute predicted image from residual
            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
            pred_mean = clip_image_coeff * image - clip_noise_coeff * noise_residual
            pred_mean = torch.clamp(pred_mean, -1, 1)
            prev_image = clip_coeff * pred_mean + image_coeff * image

            # iv) sample variance
            # NOTE(review): the variance term is added directly, so
            # `sample_variance` presumably already returns sigma * noise —
            # confirm against the scheduler implementation.
            prev_variance = self.noise_scheduler.sample_variance(
                t, prev_image.shape, device=torch_device, generator=generator
            )

            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
            image = prev_image + prev_variance

        # Bug fix: the merged diff left the new __call__ without a return
        # (and kept an old, shadowing __call__ that referenced the removed
        # `gaussian_sampler` attribute). Return the final sample.
        return image
......@@ -45,7 +45,7 @@ class DiffusionPipeline(Config):
config_name = "model_index.json"
def __init__(self, **kwargs):
def register_modules(self, **kwargs):
for name, module in kwargs.items():
# retrieve library
library = module.__module__.split(".")[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment