Commit 0183bf13 authored by Patrick von Platen's avatar Patrick von Platen
Browse files
parents f6e8c8c0 9a4d53a4
...@@ -226,6 +226,56 @@ image_pil = PIL.Image.fromarray(image_processed[0]) ...@@ -226,6 +226,56 @@ image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("test.png") image_pil.save("test.png")
``` ```
#### **Example 1024x1024 image generation with SDE VE**
See [paper](https://arxiv.org/abs/2011.13456) for more information on SDE VE.
```python
from diffusers import DiffusionPipeline
import torch
import PIL.Image
import numpy as np
torch.manual_seed(32)
score_sde_sv = DiffusionPipeline.from_pretrained("fusing/ffhq_ncsnpp")
# Note this might take up to 3 minutes on a GPU
image = score_sde_sv(num_inference_steps=2000)
image = image.permute(0, 2, 3, 1).cpu().numpy()
image = np.clip(image * 255, 0, 255).astype(np.uint8)
image_pil = PIL.Image.fromarray(image[0])
# save image
image_pil.save("test.png")
```
#### **Example 32x32 image generation with SDE VP**
See [paper](https://arxiv.org/abs/2011.13456) for more information on SDE VP.
```python
from diffusers import DiffusionPipeline
import torch
import PIL.Image
import numpy as np
torch.manual_seed(32)
score_sde_sv = DiffusionPipeline.from_pretrained("fusing/cifar10-ddpmpp-deep-vp")
# Note this might take up to 3 minutes on a GPU
image = score_sde_sv(num_inference_steps=1000)
image = image.permute(0, 2, 3, 1).cpu().numpy()
image = np.clip(image * 255, 0, 255).astype(np.uint8)
image_pil = PIL.Image.fromarray(image[0])
# save image
image_pil.save("test.png")
```
#### **Text to Image generation with Latent Diffusion** #### **Text to Image generation with Latent Diffusion**
_Note: To use latent diffusion install transformers from [this branch](https://github.com/patil-suraj/transformers/tree/ldm-bert)._ _Note: To use latent diffusion install transformers from [this branch](https://github.com/patil-suraj/transformers/tree/ldm-bert)._
......
...@@ -9,8 +9,16 @@ __version__ = "0.0.4" ...@@ -9,8 +9,16 @@ __version__ = "0.0.4"
from .modeling_utils import ModelMixin from .modeling_utils import ModelMixin
from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel
from .pipeline_utils import DiffusionPipeline from .pipeline_utils import DiffusionPipeline
from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline, ScoreSdeVePipeline, ScoreSdeVpPipeline
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin, VeSdeScheduler from .schedulers import (
DDIMScheduler,
DDPMScheduler,
GradTTSScheduler,
PNDMScheduler,
SchedulerMixin,
ScoreSdeVeScheduler,
ScoreSdeVpScheduler,
)
if is_transformers_available(): if is_transformers_available():
......
...@@ -766,7 +766,6 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -766,7 +766,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
continuous=continuous, continuous=continuous,
) )
self.act = act = get_act(nonlinearity) self.act = act = get_act(nonlinearity)
# self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config)))
self.nf = nf self.nf = nf
self.num_res_blocks = num_res_blocks self.num_res_blocks = num_res_blocks
...@@ -939,7 +938,7 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -939,7 +938,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
self.all_modules = nn.ModuleList(modules) self.all_modules = nn.ModuleList(modules)
def forward(self, x, time_cond): def forward(self, x, time_cond, sigmas=None):
# timestep/noise_level embedding; only for continuous training # timestep/noise_level embedding; only for continuous training
modules = self.all_modules modules = self.all_modules
m_idx = 0 m_idx = 0
...@@ -952,7 +951,7 @@ class NCSNpp(ModelMixin, ConfigMixin): ...@@ -952,7 +951,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
elif self.embedding_type == "positional": elif self.embedding_type == "positional":
# Sinusoidal positional embeddings. # Sinusoidal positional embeddings.
timesteps = time_cond timesteps = time_cond
used_sigmas = self.sigmas[time_cond.long()] used_sigmas = sigmas
temb = get_timestep_embedding(timesteps, self.nf) temb = get_timestep_embedding(timesteps, self.nf)
else: else:
......
...@@ -3,9 +3,11 @@ from .pipeline_bddm import BDDMPipeline ...@@ -3,9 +3,11 @@ from .pipeline_bddm import BDDMPipeline
from .pipeline_ddim import DDIMPipeline from .pipeline_ddim import DDIMPipeline
from .pipeline_ddpm import DDPMPipeline from .pipeline_ddpm import DDPMPipeline
from .pipeline_pndm import PNDMPipeline from .pipeline_pndm import PNDMPipeline
from .pipeline_score_sde_ve import ScoreSdeVePipeline
from .pipeline_score_sde_vp import ScoreSdeVpPipeline
# from .pipeline_score_sde import NCSNppPipeline # from .pipeline_score_sde import ScoreSdeVePipeline
if is_transformers_available(): if is_transformers_available():
......
#!/usr/bin/env python3 #!/usr/bin/env python3
import numpy as np
import torch import torch
import PIL
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
# from configs.ve import ffhq_ncsnpp_continuous as configs # TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
# from configs.ve import cifar10_ncsnpp_continuous as configs class ScoreSdeVePipeline(DiffusionPipeline):
# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note usually we need to restore ema etc...
# ema restored checkpoint used from below
torch.backends.cuda.matmul.allow_tf32 = False
torch.manual_seed(0)
class NCSNppPipeline(DiffusionPipeline):
def __init__(self, model, scheduler): def __init__(self, model, scheduler):
super().__init__() super().__init__()
self.register_modules(model=model, scheduler=scheduler) self.register_modules(model=model, scheduler=scheduler)
def __call__(self, generator=None): def __call__(self, num_inference_steps=2000, generator=None):
N = self.scheduler.config.N
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
img_size = self.model.config.image_size img_size = self.model.config.image_size
channels = self.model.config.num_channels channels = self.model.config.num_channels
shape = (1, channels, img_size, img_size) shape = (1, channels, img_size, img_size)
model = torch.nn.DataParallel(self.model.to(device)) model = self.model.to(device)
centered = False # TODO(Patrick) move to scheduler config
n_steps = 1 n_steps = 1
# Initial sample
x = torch.randn(*shape) * self.scheduler.config.sigma_max x = torch.randn(*shape) * self.scheduler.config.sigma_max
x = x.to(device) x = x.to(device)
for i in range(N): self.scheduler.set_timesteps(num_inference_steps)
sigma_t = self.scheduler.get_sigma_t(i) * torch.ones(shape[0], device=device) self.scheduler.set_sigmas(num_inference_steps)
for i, t in enumerate(self.scheduler.timesteps):
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)
for _ in range(n_steps): for _ in range(n_steps):
with torch.no_grad(): with torch.no_grad():
result = model(x, sigma_t) result = self.model(x, sigma_t)
x = self.scheduler.step_correct(result, x) x = self.scheduler.step_correct(result, x)
with torch.no_grad(): with torch.no_grad():
result = model(x, sigma_t) result = model(x, sigma_t)
x, x_mean = self.scheduler.step_pred(result, x, i) x, x_mean = self.scheduler.step_pred(result, x, t)
x = x_mean
if centered:
x = (x + 1.0) / 2.0
return x
pipeline = NCSNppPipeline.from_pretrained("/home/patrick/ffhq_ncsnpp")
x = pipeline()
# for 5 cifar10
# x_sum = 106071.9922
# x_mean = 34.52864456176758
# for 1000 cifar10
# x_sum = 461.9700
# x_mean = 0.1504
# for N=2 for 1024
x_sum = 3382810112.0
x_mean = 1075.366455078125
def check_x_sum_x_mean(x, x_sum, x_mean):
assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
check_x_sum_x_mean(x, x_sum, x_mean)
def save_image(x):
image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("../images/hey.png")
# save_image(x) return x_mean
#!/usr/bin/env python3
import torch
from diffusers import DiffusionPipeline
# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
class ScoreSdeVpPipeline(DiffusionPipeline):
    """Unconditional image sampling pipeline built from a model and a scheduler.

    The pipeline draws a Gaussian sample, then for every timestep produced by
    ``scheduler.set_timesteps`` evaluates the model and lets
    ``scheduler.step_pred`` update the sample.
    """

    def __init__(self, model, scheduler):
        super().__init__()
        self.register_modules(model=model, scheduler=scheduler)

    def __call__(self, num_inference_steps=1000, generator=None):
        """Generate one sample.

        Args:
            num_inference_steps: Number of scheduler timesteps to iterate over.
                Must be at least 1.
            generator: Optional ``torch.Generator`` used to draw the initial
                Gaussian sample. The noise added inside ``scheduler.step_pred``
                is not drawn from this generator.

        Returns:
            The ``x_mean`` returned by the last ``scheduler.step_pred`` call,
            shifted and scaled by ``(x_mean + 1.0) / 2.0``. Its shape is
            ``(1, num_channels, image_size, image_size)`` as read from the
            model config, assuming ``step_pred`` preserves the sample shape.

        Raises:
            ValueError: If ``num_inference_steps`` is smaller than 1.
        """
        # With zero steps the loop below never runs and there is no result to
        # return; fail with a clear message instead of an UnboundLocalError.
        if num_inference_steps < 1:
            raise ValueError(f"`num_inference_steps` must be >= 1, but is {num_inference_steps}.")

        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        img_size = self.model.config.image_size
        channels = self.model.config.num_channels
        shape = (1, channels, img_size, img_size)

        model = self.model.to(device)

        # Draw the initial sample. A user-supplied generator is honoured here;
        # the sample is created on the generator's own device because
        # torch.randn requires both to match, and is moved afterwards.
        if generator is None:
            x = torch.randn(*shape)
        else:
            x = torch.randn(*shape, generator=generator, device=generator.device)
        x = x.to(device)

        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.scheduler.timesteps:
            # Expand the scalar timestep to one entry per batch element.
            t = t * torch.ones(shape[0], device=device)
            # The model is fed the timestep rescaled by (num_inference_steps - 1).
            scaled_t = t * (num_inference_steps - 1)

            with torch.no_grad():
                result = model(x, scaled_t)

            x, x_mean = self.scheduler.step_pred(result, x, t)

        # Return the mean of the final step, shifted and scaled by (x + 1) / 2.
        x_mean = (x_mean + 1.0) / 2.0

        return x_mean
...@@ -20,5 +20,6 @@ from .scheduling_ddim import DDIMScheduler ...@@ -20,5 +20,6 @@ from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler from .scheduling_ddpm import DDPMScheduler
from .scheduling_grad_tts import GradTTSScheduler from .scheduling_grad_tts import GradTTSScheduler
from .scheduling_pndm import PNDMScheduler from .scheduling_pndm import PNDMScheduler
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
from .scheduling_ve_sde import VeSdeScheduler
# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved. # Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
# TODO(Patrick, Anton, Suraj) - make scheduler framework-independent and clean up a bit
import numpy as np import numpy as np
import torch import torch
...@@ -21,34 +23,42 @@ from ..configuration_utils import ConfigMixin ...@@ -21,34 +23,42 @@ from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
class VeSdeScheduler(SchedulerMixin, ConfigMixin): class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, N=2, sampling_eps=1e-5, tensor_format="np"): def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, sampling_eps=1e-5, tensor_format="np"):
super().__init__() super().__init__()
self.register_to_config( self.register_to_config(
snr=snr, snr=snr,
sigma_min=sigma_min, sigma_min=sigma_min,
sigma_max=sigma_max, sigma_max=sigma_max,
N=N,
sampling_eps=sampling_eps, sampling_eps=sampling_eps,
) )
# (PVP) - clean up with .config.
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.snr = snr
self.N = N
self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))
self.timesteps = torch.linspace(1, sampling_eps, N)
def get_sigma_t(self, t): self.sigmas = None
return self.sigma_min * (self.sigma_max / self.sigma_min) ** self.timesteps[t] self.discrete_sigmas = None
self.timesteps = None
def set_timesteps(self, num_inference_steps):
self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
def set_sigmas(self, num_inference_steps):
if self.timesteps is None:
self.set_timesteps(num_inference_steps)
self.discrete_sigmas = torch.exp(
torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
)
self.sigmas = torch.tensor(
[self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
)
def step_pred(self, result, x, t): def step_pred(self, result, x, t):
t = self.timesteps[t] * torch.ones(x.shape[0], device=x.device) # TODO(Patrick) better comments + non-PyTorch
t = t * torch.ones(x.shape[0], device=x.device)
timestep = (t * (len(self.timesteps) - 1)).long()
timestep = (t * (self.N - 1)).long()
sigma = self.discrete_sigmas.to(t.device)[timestep] sigma = self.discrete_sigmas.to(t.device)[timestep]
adjacent_sigma = torch.where( adjacent_sigma = torch.where(
timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(t.device) timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(timestep.device)
) )
f = torch.zeros_like(x) f = torch.zeros_like(x)
G = torch.sqrt(sigma**2 - adjacent_sigma**2) G = torch.sqrt(sigma**2 - adjacent_sigma**2)
...@@ -61,10 +71,11 @@ class VeSdeScheduler(SchedulerMixin, ConfigMixin): ...@@ -61,10 +71,11 @@ class VeSdeScheduler(SchedulerMixin, ConfigMixin):
return x, x_mean return x, x_mean
def step_correct(self, result, x): def step_correct(self, result, x):
# TODO(Patrick) better comments + non-PyTorch
noise = torch.randn_like(x) noise = torch.randn_like(x)
grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean() grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean()
noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
step_size = (self.snr * noise_norm / grad_norm) ** 2 * 2 step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
step_size = step_size * torch.ones(x.shape[0], device=x.device) step_size = step_size * torch.ones(x.shape[0], device=x.device)
x_mean = x + step_size[:, None, None, None] * result x_mean = x + step_size[:, None, None, None] * result
......
# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
# TODO(Patrick, Anton, Suraj) - make scheduler framework-independent and clean up a bit
import numpy as np
import torch
from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin
class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
    """Scheduler holding the timestep grid and update rule for a VP SDE sampler.

    ``beta_min``, ``beta_max`` and ``sampling_eps`` are stored in the config.
    ``tensor_format`` is accepted by the constructor but not used in this class
    body.
    """

    def __init__(self, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
        super().__init__()
        self.register_to_config(beta_min=beta_min, beta_max=beta_max, sampling_eps=sampling_eps)

        # Placeholders; only `timesteps` is filled in by this class
        # (see `set_timesteps`).
        self.sigmas = None
        self.discrete_sigmas = None
        self.timesteps = None

    def set_timesteps(self, num_inference_steps):
        """Build `num_inference_steps` evenly spaced timesteps from 1 down to `sampling_eps`."""
        first, last = 1, self.config.sampling_eps
        self.timesteps = torch.linspace(first, last, num_inference_steps)

    def step_pred(self, result, x, t):
        """Apply one update step to the sample.

        Args:
            result: Model output for the current sample.
            x: Current sample (assumed 4-D, batch first - TODO confirm).
            t: 1-D tensor of timesteps, one entry per batch element.

        Returns:
            Tuple ``(x, x_mean)``: the updated sample with fresh Gaussian noise
            added, and the same update without the noise term.
        """
        # TODO(Patrick) better comments + non-PyTorch
        beta_min = self.config.beta_min
        beta_max = self.config.beta_max

        def per_sample(values):
            # Reshape a (batch,) tensor so it broadcasts against a 4-D sample.
            return values[:, None, None, None]

        # Post-process the model output: negate it and divide by `std`.
        log_mean_coeff = -0.25 * t**2 * (beta_max - beta_min) - 0.5 * t * beta_min
        std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
        scaled_result = -result / per_sample(std)

        # Negative step size: the timesteps run from 1 towards `sampling_eps`.
        step = -1.0 / len(self.timesteps)

        beta_t = beta_min + t * (beta_max - beta_min)
        noise_scale = torch.sqrt(beta_t)

        drift = -0.5 * per_sample(beta_t) * x
        drift = drift - per_sample(noise_scale) ** 2 * scaled_result
        mean = x + drift * step

        # Add noise scaled by the diffusion coefficient and sqrt(|step|).
        noise = torch.randn_like(x)
        noisy = mean + per_sample(noise_scale) * np.sqrt(-step) * noise

        return noisy, mean
...@@ -33,8 +33,13 @@ from diffusers import ( ...@@ -33,8 +33,13 @@ from diffusers import (
GradTTSPipeline, GradTTSPipeline,
GradTTSScheduler, GradTTSScheduler,
LatentDiffusionPipeline, LatentDiffusionPipeline,
NCSNpp,
PNDMPipeline, PNDMPipeline,
PNDMScheduler, PNDMScheduler,
ScoreSdeVePipeline,
ScoreSdeVeScheduler,
ScoreSdeVpPipeline,
ScoreSdeVpScheduler,
UNetGradTTSModel, UNetGradTTSModel,
UNetLDMModel, UNetLDMModel,
UNetModel, UNetModel,
...@@ -721,6 +726,40 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -721,6 +726,40 @@ class PipelineTesterMixin(unittest.TestCase):
) )
assert (mel_spec[0, :3, :3].cpu().flatten() - expected_slice).abs().max() < 1e-2 assert (mel_spec[0, :3, :3].cpu().flatten() - expected_slice).abs().max() < 1e-2
@slow
def test_score_sde_ve_pipeline(self):
    """Run the VE pipeline for 2 steps and compare abs-sum / abs-mean with reference values."""
    torch.manual_seed(0)

    repo_id = "fusing/ffhq_ncsnpp"
    pipeline = ScoreSdeVePipeline(
        model=NCSNpp.from_pretrained(repo_id),
        scheduler=ScoreSdeVeScheduler.from_config(repo_id),
    )

    sample = pipeline(num_inference_steps=2)

    # Reference statistics for the seeded 2-step run.
    reference_sum = 3382810112.0
    reference_mean = 1075.366455078125

    sum_error = (sample.abs().sum() - reference_sum).abs().cpu().item()
    mean_error = (sample.abs().mean() - reference_mean).abs().cpu().item()

    assert sum_error < 1e-2
    assert mean_error < 1e-4
@slow
def test_score_sde_vp_pipeline(self):
    """Run the VP pipeline for 10 steps and compare abs-sum / abs-mean with reference values."""
    repo_id = "fusing/cifar10-ddpmpp-vp"
    pipeline = ScoreSdeVpPipeline(
        model=NCSNpp.from_pretrained(repo_id),
        scheduler=ScoreSdeVpScheduler.from_config(repo_id),
    )

    # Seed only after the pipeline is assembled, right before sampling.
    torch.manual_seed(0)
    sample = pipeline(num_inference_steps=10)

    # Reference statistics for the seeded 10-step run.
    reference_sum = 4183.2012
    reference_mean = 1.3617

    sum_error = (sample.abs().sum() - reference_sum).abs().cpu().item()
    mean_error = (sample.abs().mean() - reference_mean).abs().cpu().item()

    assert sum_error < 1e-2
    assert mean_error < 1e-4
def test_module_from_pipeline(self): def test_module_from_pipeline(self):
model = DiffWave(num_res_layers=4) model = DiffWave(num_res_layers=4)
noise_scheduler = DDPMScheduler(timesteps=12) noise_scheduler = DDPMScheduler(timesteps=12)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment