Commit de810814 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

finish first version sde ve

parent bc2d586d
...@@ -10,7 +10,7 @@ from .modeling_utils import ModelMixin ...@@ -10,7 +10,7 @@ from .modeling_utils import ModelMixin
from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel
from .pipeline_utils import DiffusionPipeline from .pipeline_utils import DiffusionPipeline
from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin, VeSdeScheduler
if is_transformers_available(): if is_transformers_available():
......
...@@ -15,10 +15,6 @@ ...@@ -15,10 +15,6 @@
# helpers functions # helpers functions
from ..modeling_utils import ModelMixin
from ..configuration_utils import ConfigMixin
import functools import functools
import math import math
import string import string
...@@ -28,16 +24,15 @@ import torch ...@@ -28,16 +24,15 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
return upfirdn2d_native( return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
)
def upfirdn2d_native( def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
_, channel, in_h, in_w = input.shape _, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1) input = input.reshape(-1, in_h, in_w, 1)
...@@ -48,9 +43,7 @@ def upfirdn2d_native( ...@@ -48,9 +43,7 @@ def upfirdn2d_native(
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor) out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad( out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
)
out = out[ out = out[
:, :,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
...@@ -59,9 +52,7 @@ def upfirdn2d_native( ...@@ -59,9 +52,7 @@ def upfirdn2d_native(
] ]
out = out.permute(0, 3, 1, 2) out = out.permute(0, 3, 1, 2)
out = out.reshape( out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
)
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w) out = F.conv2d(out, w)
out = out.reshape( out = out.reshape(
...@@ -350,7 +341,7 @@ conv3x3 = ddpm_conv3x3 ...@@ -350,7 +341,7 @@ conv3x3 = ddpm_conv3x3
def _einsum(a, b, c, x, y): def _einsum(a, b, c, x, y):
einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c)) einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
return torch.einsum(einsum_str, x, y) return torch.einsum(einsum_str, x, y)
......
...@@ -5,6 +5,9 @@ from .pipeline_ddpm import DDPMPipeline ...@@ -5,6 +5,9 @@ from .pipeline_ddpm import DDPMPipeline
from .pipeline_pndm import PNDMPipeline from .pipeline_pndm import PNDMPipeline
# from .pipeline_score_sde import NCSNppPipeline
if is_transformers_available(): if is_transformers_available():
from .pipeline_glide import GlidePipeline from .pipeline_glide import GlidePipeline
from .pipeline_latent_diffusion import LatentDiffusionPipeline from .pipeline_latent_diffusion import LatentDiffusionPipeline
......
#!/usr/bin/env python3
import numpy as np
import torch
import PIL
from diffusers import DiffusionPipeline
# from configs.ve import ffhq_ncsnpp_continuous as configs
# from configs.ve import cifar10_ncsnpp_continuous as configs
# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note usually we need to restore ema etc...
# ema restored checkpoint used from below
torch.backends.cuda.matmul.allow_tf32 = False
torch.manual_seed(0)
class NCSNppPipeline(DiffusionPipeline):
def __init__(self, model, scheduler):
super().__init__()
self.register_modules(model=model, scheduler=scheduler)
def __call__(self, generator=None):
N = self.scheduler.config.N
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
img_size = self.model.config.image_size
channels = self.model.config.num_channels
shape = (1, channels, img_size, img_size)
model = torch.nn.DataParallel(self.model.to(device))
centered = False
n_steps = 1
# Initial sample
x = torch.randn(*shape) * self.scheduler.config.sigma_max
x = x.to(device)
for i in range(N):
sigma_t = self.scheduler.get_sigma_t(i) * torch.ones(shape[0], device=device)
for _ in range(n_steps):
with torch.no_grad():
result = model(x, sigma_t)
x = self.scheduler.step_correct(result, x)
with torch.no_grad():
result = model(x, sigma_t)
x, x_mean = self.scheduler.step_pred(result, x, i)
x = x_mean
if centered:
x = (x + 1.0) / 2.0
return x
pipeline = NCSNppPipeline.from_pretrained("/home/patrick/ffhq_ncsnpp")
x = pipeline()
# for 5 cifar10
# x_sum = 106071.9922
# x_mean = 34.52864456176758
# for 1000 cifar10
# x_sum = 461.9700
# x_mean = 0.1504
# for N=2 for 1024
x_sum = 3382810112.0
x_mean = 1075.366455078125
def check_x_sum_x_mean(x, x_sum, x_mean):
assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
check_x_sum_x_mean(x, x_sum, x_mean)
def save_image(x):
image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("../images/hey.png")
# save_image(x)
...@@ -21,3 +21,4 @@ from .scheduling_ddpm import DDPMScheduler ...@@ -21,3 +21,4 @@ from .scheduling_ddpm import DDPMScheduler
from .scheduling_grad_tts import GradTTSScheduler from .scheduling_grad_tts import GradTTSScheduler
from .scheduling_pndm import PNDMScheduler from .scheduling_pndm import PNDMScheduler
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
from .scheduling_ve_sde import VeSdeScheduler
# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import numpy as np
import torch
from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin
class VeSdeScheduler(SchedulerMixin, ConfigMixin):
def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, N=2, sampling_eps=1e-5, tensor_format="np"):
super().__init__()
self.register_to_config(
snr=snr,
sigma_min=sigma_min,
sigma_max=sigma_max,
N=N,
sampling_eps=sampling_eps,
)
# (PVP) - clean up with .config.
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.snr = snr
self.N = N
self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
self.timesteps = torch.linspace(1, sampling_eps, N)
def get_sigma_t(self, t):
return self.sigma_min * (self.sigma_max / self.sigma_min) ** self.timesteps[t]
def step_pred(self, result, x, t):
t = self.timesteps[t] * torch.ones(x.shape[0], device=x.device)
timestep = (t * (self.N - 1)).long()
sigma = self.discrete_sigmas.to(t.device)[timestep]
adjacent_sigma = torch.where(
timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(t.device)
)
f = torch.zeros_like(x)
G = torch.sqrt(sigma**2 - adjacent_sigma**2)
f = f - G[:, None, None, None] ** 2 * result
z = torch.randn_like(x)
x_mean = x - f
x = x_mean + G[:, None, None, None] * z
return x, x_mean
def step_correct(self, result, x):
noise = torch.randn_like(x)
grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean()
noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
step_size = (self.snr * noise_norm / grad_norm) ** 2 * 2
step_size = step_size * torch.ones(x.shape[0], device=x.device)
x_mean = x + step_size[:, None, None, None] * result
x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise
return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment