Commit dc6d0286 authored by Patrick von Platen
Browse files

add vp sampler

parent d5c527a4
...@@ -9,7 +9,7 @@ __version__ = "0.0.4" ...@@ -9,7 +9,7 @@ __version__ = "0.0.4"
from .modeling_utils import ModelMixin from .modeling_utils import ModelMixin
from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel
from .pipeline_utils import DiffusionPipeline from .pipeline_utils import DiffusionPipeline
from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline, ScoreSdeVePipeline from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline, ScoreSdeVePipeline, ScoreSdeVpPipeline
from .schedulers import ( from .schedulers import (
DDIMScheduler, DDIMScheduler,
DDPMScheduler, DDPMScheduler,
...@@ -17,6 +17,7 @@ from .schedulers import ( ...@@ -17,6 +17,7 @@ from .schedulers import (
PNDMScheduler, PNDMScheduler,
SchedulerMixin, SchedulerMixin,
ScoreSdeVeScheduler, ScoreSdeVeScheduler,
ScoreSdeVpScheduler,
) )
......
...@@ -766,7 +766,7 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -766,7 +766,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
continuous=continuous, continuous=continuous,
) )
self.act = act = get_act(nonlinearity) self.act = act = get_act(nonlinearity)
# self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config))) self.register_buffer('sigmas', torch.tensor(np.linspace(np.log(50), np.log(0.01), 10)))
self.nf = nf self.nf = nf
self.num_res_blocks = num_res_blocks self.num_res_blocks = num_res_blocks
......
...@@ -4,6 +4,7 @@ from .pipeline_ddim import DDIMPipeline ...@@ -4,6 +4,7 @@ from .pipeline_ddim import DDIMPipeline
from .pipeline_ddpm import DDPMPipeline from .pipeline_ddpm import DDPMPipeline
from .pipeline_pndm import PNDMPipeline from .pipeline_pndm import PNDMPipeline
from .pipeline_score_sde_ve import ScoreSdeVePipeline from .pipeline_score_sde_ve import ScoreSdeVePipeline
from .pipeline_score_sde_vp import ScoreSdeVpPipeline
# from .pipeline_score_sde import ScoreSdeVePipeline # from .pipeline_score_sde import ScoreSdeVePipeline
......
File mode changed from 100755 to 100644
#!/usr/bin/env python3
import torch
from diffusers import DiffusionPipeline
# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
class ScoreSdeVpPipeline(DiffusionPipeline):
    """Sampling pipeline that drives a score model with a VP-SDE style scheduler.

    The pipeline draws one Gaussian sample shaped from the model config,
    then repeatedly calls the model and ``scheduler.step_pred`` over the
    scheduler's timesteps.
    """

    def __init__(self, model, scheduler):
        """Register ``model`` and ``scheduler`` as modules of this pipeline."""
        super().__init__()
        self.register_modules(model=model, scheduler=scheduler)

    def __call__(self, num_inference_steps=1000, generator=None):
        """Run the sampling loop and return a single rescaled sample.

        Args:
            num_inference_steps: Number of timesteps requested from the
                scheduler. Must be at least 1.
            generator: Optional ``torch.Generator`` used to draw the initial
                noise. The noise is created on the CPU and then moved to the
                target device, so this is presumably a CPU generator.
                ``None`` falls back to the global RNG.

        Returns:
            The second value returned by ``scheduler.step_pred`` on the last
            step, mapped through ``(x + 1) / 2``.

        Raises:
            ValueError: If ``num_inference_steps`` is smaller than 1.
        """
        if num_inference_steps < 1:
            # Without at least one step the loop never runs and there is no
            # sample to return.
            raise ValueError(
                f"`num_inference_steps` must be >= 1, got {num_inference_steps}"
            )

        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        img_size = self.model.config.image_size
        channels = self.model.config.num_channels
        shape = (1, channels, img_size, img_size)

        # Use the same beta range the scheduler was configured with so the
        # score rescaling below cannot drift from the scheduler's own values.
        beta_min = self.scheduler.config.beta_min
        beta_max = self.scheduler.config.beta_max

        model = self.model.to(device)

        # Forward the caller's generator so the initial noise is reproducible.
        x = torch.randn(*shape, generator=generator).to(device)

        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.scheduler.timesteps:
            # Expand the scalar timestep to one entry per batch element.
            t = t * torch.ones(shape[0], device=device)

            # NOTE(review): the model conditioning is scaled by the number of
            # inference steps - confirm this matches how the model was trained.
            model_timestep = t * (num_inference_steps - 1)

            with torch.no_grad():
                result = model(x, model_timestep)

            # Divide the model output by the standard deviation derived from
            # the linear beta range at time t, and flip its sign.
            log_mean_coeff = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min
            std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
            result = -result / std[:, None, None, None]

            x, x_mean = self.scheduler.step_pred(result, x, t)

        # presumably maps samples from [-1, 1] to [0, 1] - TODO confirm
        return (x_mean + 1.0) / 2.0
...@@ -22,3 +22,4 @@ from .scheduling_grad_tts import GradTTSScheduler ...@@ -22,3 +22,4 @@ from .scheduling_grad_tts import GradTTSScheduler
from .scheduling_pndm import PNDMScheduler from .scheduling_pndm import PNDMScheduler
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler
# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
# TODO(Patrick, Anton, Suraj) - make scheduler framework-independent and clean up a bit
import numpy as np
import torch
from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin
class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
    """Scheduler that performs one noisy update step per timestep.

    The per-step coefficient ``beta_t`` is interpolated linearly between the
    configured ``beta_min`` and ``beta_max``.
    """

    def __init__(self, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
        """Store the configuration and reset the timestep state.

        Args:
            beta_min: Value of ``beta_t`` at ``t = 0``.
            beta_max: Value of ``beta_t`` at ``t = 1``.
            sampling_eps: Last (smallest) timestep produced by
                ``set_timesteps``.
            tensor_format: Accepted for interface compatibility; not used by
                the code in this class.
        """
        super().__init__()
        self.register_to_config(
            beta_min=beta_min,
            beta_max=beta_max,
            sampling_eps=sampling_eps,
        )

        self.sigmas = None
        self.discrete_sigmas = None
        self.timesteps = None

    def set_timesteps(self, num_inference_steps):
        """Create ``num_inference_steps`` evenly spaced timesteps from 1 down to ``sampling_eps``."""
        self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)

    def step_pred(self, result, x, t):
        """Advance the sample ``x`` by one step at time ``t``.

        Args:
            result: Tensor multiplied by ``beta_t`` and subtracted from the
                drift; presumably the score estimate - confirm against caller.
            x: Current sample. The indexing below broadcasts ``t`` against a
                4-D tensor.
            t: 1-D tensor of timesteps, one entry per batch element.

        Returns:
            A tuple ``(x, x_mean)``: the updated sample with noise added, and
            the same update before the noise term is added.

        Raises:
            ValueError: If ``set_timesteps`` has not been called yet.
        """
        if self.timesteps is None:
            raise ValueError("`set_timesteps` must be called before `step_pred`")

        # Negative step size: 1 / number of timesteps, moving towards t = 0.
        dt = -1.0 / len(self.timesteps)
        z = torch.randn_like(x)

        # Read the beta range from the config, as `set_timesteps` does for
        # `sampling_eps`.
        beta_min = self.config.beta_min
        beta_max = self.config.beta_max
        beta_t = beta_min + t * (beta_max - beta_min)

        # Reshape per-sample coefficients so they broadcast over (C, H, W).
        beta_t_b = beta_t[:, None, None, None]
        diffusion_b = torch.sqrt(beta_t)[:, None, None, None]

        drift = -0.5 * beta_t_b * x
        drift = drift - diffusion_b ** 2 * result

        x_mean = x + drift * dt
        x = x_mean + diffusion_b * np.sqrt(-dt) * z
        return x, x_mean
...@@ -38,6 +38,8 @@ from diffusers import ( ...@@ -38,6 +38,8 @@ from diffusers import (
PNDMScheduler, PNDMScheduler,
ScoreSdeVePipeline, ScoreSdeVePipeline,
ScoreSdeVeScheduler, ScoreSdeVeScheduler,
ScoreSdeVpPipeline,
ScoreSdeVpScheduler,
UNetGradTTSModel, UNetGradTTSModel,
UNetLDMModel, UNetLDMModel,
UNetModel, UNetModel,
...@@ -741,6 +743,23 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -741,6 +743,23 @@ class PipelineTesterMixin(unittest.TestCase):
assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2 assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4 assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
@slow
def test_score_sde_vp_pipeline(self):
    """Run the VP score-SDE pipeline for 10 steps and compare image statistics to recorded values."""
    # NOTE(review): the checkpoint is loaded from a machine-local absolute
    # path - confirm it is available wherever this slow test runs.
    pipeline = ScoreSdeVpPipeline(
        model=NCSNpp.from_pretrained("/home/patrick/cifar10-ddpmpp-vp"),
        scheduler=ScoreSdeVpScheduler(),
    )

    # Seed the global RNG so the sampled image is reproducible.
    torch.manual_seed(0)
    image = pipeline(num_inference_steps=10)

    expected_image_sum = 4183.2012
    expected_image_mean = 1.3617

    abs_image = image.abs()
    sum_error = (abs_image.sum() - expected_image_sum).abs().cpu().item()
    mean_error = (abs_image.mean() - expected_image_mean).abs().cpu().item()

    assert sum_error < 1e-2
    assert mean_error < 1e-4
def test_module_from_pipeline(self): def test_module_from_pipeline(self):
model = DiffWave(num_res_layers=4) model = DiffWave(num_res_layers=4)
noise_scheduler = DDPMScheduler(timesteps=12) noise_scheduler = DDPMScheduler(timesteps=12)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment