Unverified commit 63c68d97 authored by Nathan Lambert, committed by GitHub

VE/VP SDE updates (#90)



* improve comments for sde_ve scheduler, init tests

* more comments, tweaking pipelines

* timesteps --> num_training_timesteps, some comments

* merge cpu test, add m1 data

* fix scheduler tests with num_train_timesteps

* make np compatible, add tests for sde ve

* minor default variable fixes

* make style and fix-copies
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent ba3c9a9a
@@ -18,9 +18,6 @@ class ScoreSdeVePipeline(DiffusionPipeline):
         model = self.model.to(device)

-        # TODO(Patrick) move to scheduler config
-        n_steps = 1
-
         x = torch.randn(*shape) * self.scheduler.config.sigma_max
         x = x.to(device)
@@ -30,7 +27,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
         for i, t in enumerate(self.scheduler.timesteps):
             sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)

-            for _ in range(n_steps):
+            for _ in range(self.scheduler.correct_steps):
                 with torch.no_grad():
                     result = self.model(x, sigma_t)
......
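For orientation, the predictor-corrector structure the pipeline runs above can be sketched in isolation. This is a toy stand-in rather than the real pipeline: the score function and both update rules below are placeholders, and only the loop shape (corrector passes taken from a `correct_steps` setting, then one predictor step per timestep) mirrors the diff.

import torch

# Toy stand-ins mimicking the scheduler fields used above; none of this is the real API.
correct_steps = 1
sigmas = torch.tensor([10.0, 1.0, 0.1])
timesteps = torch.linspace(1.0, 1e-5, len(sigmas))

def dummy_score(x, sigma_t):
    # A real score model would estimate grad_x log p_t(x); this just shrinks the sample.
    return -x / (1.0 + sigma_t[:, None, None, None] ** 2)

shape = (2, 3, 8, 8)
x = torch.randn(*shape) * sigmas[0]

for i, t in enumerate(timesteps):
    sigma_t = sigmas[i] * torch.ones(shape[0])
    # corrector: `correct_steps` Langevin-style passes before each predictor step
    for _ in range(correct_steps):
        score = dummy_score(x, sigma_t)
        x = x + 0.01 * score  # placeholder correction update
    # predictor: one reverse-step update (also a placeholder)
    score = dummy_score(x, sigma_t)
    x = x + sigma_t[:, None, None, None] ** 2 * score / len(timesteps)

print(x.shape)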
@@ -27,6 +27,7 @@ class ScoreSdeVpPipeline(DiffusionPipeline):
             t = t * torch.ones(shape[0], device=device)
             scaled_t = t * (num_inference_steps - 1)

+            # TODO add corrector
             with torch.no_grad():
                 result = model(x, scaled_t)
......
@@ -51,7 +51,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
 class DDIMScheduler(SchedulerMixin, ConfigMixin):
     def __init__(
         self,
-        timesteps=1000,
+        num_train_timesteps=1000,
         beta_start=0.0001,
         beta_end=0.02,
         beta_schedule="linear",
@@ -62,7 +62,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
     ):
         super().__init__()
         self.register_to_config(
-            timesteps=timesteps,
+            num_train_timesteps=num_train_timesteps,
             beta_start=beta_start,
             beta_end=beta_end,
             beta_schedule=beta_schedule,
@@ -72,13 +72,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         )

         if beta_schedule == "linear":
-            self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
+            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
         elif beta_schedule == "scaled_linear":
             # this schedule is very specific to the latent diffusion model.
-            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=np.float32) ** 2
+            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
         elif beta_schedule == "squaredcos_cap_v2":
             # Glide cosine schedule
-            self.betas = betas_for_alpha_bar(timesteps)
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
@@ -88,10 +88,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         # setable values
         self.num_inference_steps = None
-        self.timesteps = np.arange(0, self.config.timesteps)[::-1].copy()
-
-        self.tensor_format = tensor_format
-        self.set_format(tensor_format=tensor_format)
+        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()

     def _get_variance(self, timestep, prev_timestep):
         alpha_prod_t = self.alphas_cumprod[timestep]
@@ -131,7 +128,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         # - pred_prev_sample -> "x_t-1"

         # 1. get previous step value (=t-1)
-        prev_timestep = timestep - self.config.timesteps // self.num_inference_steps
+        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

         # 2. compute alphas, betas
         alpha_prod_t = self.alphas_cumprod[timestep]
@@ -183,4 +180,4 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         return noisy_samples

     def __len__(self):
-        return self.config.timesteps
+        return self.config.num_train_timesteps
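The rename matters most where the training step count is read back from the config, as in the `prev_timestep` arithmetic above. A small numeric illustration with arbitrarily chosen values (not taken from this diff):

import numpy as np

# With 1000 training steps and 50 inference steps, DDIM jumps back 20 indices per step.
num_train_timesteps = 1000
num_inference_steps = 50
step = num_train_timesteps // num_inference_steps

timesteps = np.arange(0, num_train_timesteps, step)[::-1]  # 980, 960, ..., 0
prev_timesteps = timesteps - step                          # 960, 940, ..., -20

print(timesteps[:3], prev_timesteps[:3])

The very last previous index is negative; the scheduler handles that boundary case separately rather than indexing the cumulative alphas with it.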
@@ -50,7 +50,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
 class DDPMScheduler(SchedulerMixin, ConfigMixin):
     def __init__(
         self,
-        timesteps=1000,
+        num_train_timesteps=1000,
         beta_start=0.0001,
         beta_end=0.02,
         beta_schedule="linear",
@@ -62,7 +62,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
     ):
         super().__init__()
         self.register_to_config(
-            timesteps=timesteps,
+            num_train_timesteps=num_train_timesteps,
             beta_start=beta_start,
             beta_end=beta_end,
             beta_schedule=beta_schedule,
@@ -75,10 +75,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         if trained_betas is not None:
             self.betas = np.asarray(trained_betas)
         elif beta_schedule == "linear":
-            self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
+            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
         elif beta_schedule == "squaredcos_cap_v2":
             # Glide cosine schedule
-            self.betas = betas_for_alpha_bar(timesteps)
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
@@ -160,4 +160,4 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         return noisy_samples

     def __len__(self):
-        return self.config.timesteps
+        return self.config.num_train_timesteps
@@ -50,7 +50,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
 class PNDMScheduler(SchedulerMixin, ConfigMixin):
     def __init__(
         self,
-        timesteps=1000,
+        num_train_timesteps=1000,
         beta_start=0.0001,
         beta_end=0.02,
         beta_schedule="linear",
@@ -58,17 +58,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
     ):
         super().__init__()
         self.register_to_config(
-            timesteps=timesteps,
+            num_train_timesteps=num_train_timesteps,
             beta_start=beta_start,
             beta_end=beta_end,
             beta_schedule=beta_schedule,
         )

         if beta_schedule == "linear":
-            self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
+            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
         elif beta_schedule == "squaredcos_cap_v2":
             # Glide cosine schedule
-            self.betas = betas_for_alpha_bar(timesteps)
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
@@ -96,10 +96,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         if num_inference_steps in self.prk_time_steps:
             return self.prk_time_steps[num_inference_steps]

-        inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
+        inference_step_times = list(
+            range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
+        )

         prk_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
-            np.array([0, self.config.timesteps // num_inference_steps // 2]), self.pndm_order
+            np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
         )
         self.prk_time_steps[num_inference_steps] = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
@@ -109,7 +111,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         if num_inference_steps in self.time_steps:
             return self.time_steps[num_inference_steps]

-        inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
+        inference_step_times = list(
+            range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
+        )

         self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
         return self.time_steps[num_inference_steps]
@@ -135,6 +139,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         sample: Union[torch.FloatTensor, np.ndarray],
         num_inference_steps,
     ):
+        """
+        Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
+        solution to the differential equation.
+        """
         t = timestep
         prk_time_steps = self.get_prk_time_steps(num_inference_steps)
@@ -165,6 +173,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         sample: Union[torch.FloatTensor, np.ndarray],
         num_inference_steps,
     ):
+        """
+        Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
+        times to approximate the solution.
+        """
         t = timestep

         if len(self.ets) < 3:
             raise ValueError(
@@ -221,4 +233,4 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         return prev_sample

     def __len__(self):
-        return self.config.timesteps
+        return self.config.num_train_timesteps
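To see what the reformatted lookups above produce, the PRK and PLMS step tables can be recomputed outside the class with the same arithmetic; the concrete values here (1000 training steps, 10 inference steps, order 4) are only an example:

import numpy as np

num_train_timesteps = 1000
num_inference_steps = 10
pndm_order = 4  # fourth-order method, as in the scheduler

inference_step_times = list(
    range(0, num_train_timesteps, num_train_timesteps // num_inference_steps)
)

# PRK (Runge-Kutta warm-up) visits half-steps around the last `pndm_order` step times
prk_time_steps = np.array(inference_step_times[-pndm_order:]).repeat(2) + np.tile(
    np.array([0, num_train_timesteps // num_inference_steps // 2]), pndm_order
)
prk_time_steps = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))

# PLMS (linear multistep) then walks the remaining step times in reverse
plms_time_steps = list(reversed(inference_step_times[:-3]))

print(inference_step_times)
print(prk_time_steps)
print(plms_time_steps)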
@@ -15,6 +15,7 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
 
 # TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
+import pdb

 import numpy as np
 import torch
@@ -24,61 +25,132 @@ from .scheduling_utils import SchedulerMixin
 class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
-    def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, sampling_eps=1e-5, tensor_format="np"):
+    """
+    The variance exploding stochastic differential equation (SDE) scheduler.
+
+    :param snr: coefficient weighting the step from the score sample (from the network) to the random noise. :param
+    sigma_min: initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
+    distribution of the data.
+    :param sigma_max: :param sampling_eps: the end value of sampling, where timesteps decrease progessively from 1 to
+    epsilon. :param correct_steps: number of correction steps performed on a produced sample. :param tensor_format:
+    "np" or "pt" for the expected format of samples passed to the Scheduler.
+    """
+
+    def __init__(
+        self,
+        num_train_timesteps=2000,
+        snr=0.15,
+        sigma_min=0.01,
+        sigma_max=1348,
+        sampling_eps=1e-5,
+        correct_steps=1,
+        tensor_format="pt",
+    ):
         super().__init__()
         self.register_to_config(
+            num_train_timesteps=num_train_timesteps,
             snr=snr,
             sigma_min=sigma_min,
             sigma_max=sigma_max,
             sampling_eps=sampling_eps,
+            correct_steps=correct_steps,
         )

         self.sigmas = None
         self.discrete_sigmas = None
         self.timesteps = None

+        # TODO - update step to be torch-independant
+        self.set_format(tensor_format=tensor_format)
+
     def set_timesteps(self, num_inference_steps):
-        self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            self.timesteps = np.linspace(1, self.config.sampling_eps, num_inference_steps)
+        elif tensor_format == "pt":
+            self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
+        else:
+            raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

     def set_sigmas(self, num_inference_steps):
         if self.timesteps is None:
             self.set_timesteps(num_inference_steps)

-        self.discrete_sigmas = torch.exp(
-            torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
-        )
-        self.sigmas = torch.tensor(
-            [self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
-        )
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            self.discrete_sigmas = np.exp(
+                np.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
+            )
+            self.sigmas = np.array(
+                [self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
+            )
+        elif tensor_format == "pt":
+            self.discrete_sigmas = torch.exp(
+                torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
+            )
+            self.sigmas = torch.tensor(
+                [self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
+            )
+        else:
+            raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

-    def step_pred(self, result, x, t):
+    def get_adjacent_sigma(self, timesteps, t):
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1])
+        elif tensor_format == "pt":
+            return torch.where(
+                timesteps == 0, torch.zeros_like(t), self.discrete_sigmas[timesteps - 1].to(timesteps.device)
+            )
+
+        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+    def step_pred(self, score, x, t):
+        """
+        Predict the sample at the previous timestep by reversing the SDE.
+        """
         # TODO(Patrick) better comments + non-PyTorch
-        t = t * torch.ones(x.shape[0], device=x.device)
-        timestep = (t * (len(self.timesteps) - 1)).long()
+        t = self.repeat_scalar(t, x.shape[0])
+        timesteps = self.long((t * (len(self.timesteps) - 1)))

-        sigma = self.discrete_sigmas.to(t.device)[timestep]
-        adjacent_sigma = torch.where(
-            timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(timestep.device)
-        )
-        f = torch.zeros_like(x)
-        G = torch.sqrt(sigma**2 - adjacent_sigma**2)
+        sigma = self.discrete_sigmas[timesteps]
+        adjacent_sigma = self.get_adjacent_sigma(timesteps, t)
+        drift = self.zeros_like(x)
+        diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5

-        f = f - G[:, None, None, None] ** 2 * result
+        # equation 6 in the paper: the score modeled by the network is grad_x log pt(x)
+        # also equation 47 shows the analog from SDE models to ancestral sampling methods
+        drift = drift - diffusion[:, None, None, None] ** 2 * score

-        z = torch.randn_like(x)
-        x_mean = x - f
-        x = x_mean + G[:, None, None, None] * z
+        # equation 6: sample noise for the diffusion term of
+        noise = self.randn_like(x)
+        x_mean = x - drift  # subtract because `dt` is a small negative timestep
+        # TODO is the variable diffusion the correct scaling term for the noise?
+        x = x_mean + diffusion[:, None, None, None] * noise  # add impact of diffusion field g
         return x, x_mean

-    def step_correct(self, result, x):
-        # TODO(Patrick) better comments + non-PyTorch
-        noise = torch.randn_like(x)
-        grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean()
-        noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
+    def step_correct(self, score, x):
+        """
+        Correct the predicted sample based on the output score of the network. This is often run repeatedly after
+        making the prediction for the previous timestep.
+        """
+        # TODO(Patrick) non-PyTorch
+        # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
+        # sample noise for correction
+        noise = self.randn_like(x)
+
+        # compute step size from the score, the noise, and the snr
+        grad_norm = self.norm(score)
+        noise_norm = self.norm(noise)
         step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
-        step_size = step_size * torch.ones(x.shape[0], device=x.device)
-        x_mean = x + step_size[:, None, None, None] * result
-        x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise
+        step_size = self.repeat_scalar(step_size, x.shape[0])  # * self.ones(x.shape[0], device=x.device)
+
+        # compute corrected sample: score term and noise term
+        x_mean = x + step_size[:, None, None, None] * score
+        x = x_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise

         return x
+
+    def __len__(self):
+        return self.config.num_train_timesteps
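For intuition about `set_sigmas`, `get_adjacent_sigma`, and the `diffusion` term in `step_pred`, the same arithmetic can be run standalone in NumPy. The `sigma_min`/`sigma_max` defaults come from the diff; the five-step schedule and the explicit index array are made up for illustration:

import numpy as np

sigma_min, sigma_max, num_inference_steps, sampling_eps = 0.01, 1348, 5, 1e-5

timesteps = np.linspace(1, sampling_eps, num_inference_steps)
discrete_sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
sigmas = np.array([sigma_min * (sigma_max / sigma_min) ** t for t in timesteps])

# step_pred maps continuous t in (eps, 1] to integer indices len-1, ..., 0
step_indices = np.array([4, 3, 2, 1, 0])

# diffusion coefficient of the reverse step: sqrt(sigma_i^2 - sigma_{i-1}^2), with sigma_{-1} = 0
adjacent = np.where(step_indices == 0, 0.0, discrete_sigmas[step_indices - 1])
diffusion = (discrete_sigmas[step_indices] ** 2 - adjacent**2) ** 0.5

print(sigmas.round(3))
print(diffusion.round(3))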
@@ -24,9 +24,10 @@ from .scheduling_utils import SchedulerMixin
 class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
-    def __init__(self, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
+    def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
         super().__init__()
         self.register_to_config(
+            num_train_timesteps=num_train_timesteps,
             beta_min=beta_min,
             beta_max=beta_max,
             sampling_eps=sampling_eps,
@@ -39,14 +40,14 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
     def set_timesteps(self, num_inference_steps):
         self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)

-    def step_pred(self, result, x, t):
+    def step_pred(self, score, x, t):
         # TODO(Patrick) better comments + non-PyTorch
-        # postprocess model result
+        # postprocess model score
         log_mean_coeff = (
             -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
         )
         std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
-        result = -result / std[:, None, None, None]
+        score = -score / std[:, None, None, None]

         # compute
         dt = -1.0 / len(self.timesteps)
@@ -54,11 +55,14 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
         beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min)
         drift = -0.5 * beta_t[:, None, None, None] * x
         diffusion = torch.sqrt(beta_t)
-        drift = drift - diffusion[:, None, None, None] ** 2 * result
+        drift = drift - diffusion[:, None, None, None] ** 2 * score
         x_mean = x + drift * dt

         # add noise
-        z = torch.randn_like(x)
-        x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z
+        noise = torch.randn_like(x)
+        x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * noise

         return x, x_mean
+
+    def __len__(self):
+        return self.config.num_train_timesteps
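The VP update is compact enough to restate outside the class. This sketch follows the same formulas as `step_pred` above with the default `beta_min`/`beta_max`; the sample and the score are dummy tensors, so the numbers themselves mean nothing:

import numpy as np
import torch

beta_min, beta_max, num_steps = 0.1, 20.0, 1000
x = torch.randn(2, 3, 8, 8)
t = torch.full((x.shape[0],), 0.5)

# rescale the raw model output by the marginal std, as the scheduler does
log_mean_coeff = -0.25 * t**2 * (beta_max - beta_min) - 0.5 * t * beta_min
std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
score = -torch.randn_like(x) / std[:, None, None, None]  # stand-in for the network output

# one reverse step of size dt = -1 / num_steps
dt = -1.0 / num_steps
beta_t = beta_min + t * (beta_max - beta_min)
drift = -0.5 * beta_t[:, None, None, None] * x
diffusion = torch.sqrt(beta_t)
drift = drift - diffusion[:, None, None, None] ** 2 * score

x_mean = x + drift * dt
x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * torch.randn_like(x)
print(x.shape)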
@@ -53,12 +53,22 @@ class SchedulerMixin:

         raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

+    def long(self, tensor):
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            return np.int64(tensor)
+        elif tensor_format == "pt":
+            return tensor.long()
+
+        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
     def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]):
         """
         Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.

         Args:
-            timesteps: an array or tensor of values to extract.
+            values: an array or tensor of values to extract.
             broadcast_array: an array with a larger shape of K dimensions with the batch
                 dimension equal to the length of timesteps.

         Returns:
@@ -74,3 +84,39 @@ class SchedulerMixin:
             values = values.to(broadcast_array.device)

         return values
+
+    def norm(self, tensor):
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            return np.linalg.norm(tensor)
+        elif tensor_format == "pt":
+            return torch.norm(tensor.reshape(tensor.shape[0], -1), dim=-1).mean()
+
+        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+    def randn_like(self, tensor):
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            return np.random.randn(*np.shape(tensor))
+        elif tensor_format == "pt":
+            return torch.randn_like(tensor)
+
+        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+    def repeat_scalar(self, tensor, count):
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            return np.repeat(tensor, count)
+        elif tensor_format == "pt":
+            return torch.repeat_interleave(tensor, count)
+
+        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+    def zeros_like(self, tensor):
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            return np.zeros_like(tensor)
+        elif tensor_format == "pt":
+            return torch.zeros_like(tensor)
+
+        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
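All of the new helpers follow the same `tensor_format` dispatch pattern. A cut-down, self-contained version (a demo class, not the real `SchedulerMixin`) shows the idea; note that the two branches of `norm` are not numerically equivalent, since the NumPy path takes a norm over the whole array while the PyTorch path averages per-sample norms.

import numpy as np
import torch

class FormatDispatchDemo:
    # Illustrative mixin-style dispatch keyed on `tensor_format`, mirroring the helpers above.
    def __init__(self, tensor_format="pt"):
        self.tensor_format = tensor_format

    def randn_like(self, tensor):
        tensor_format = getattr(self, "tensor_format", "pt")
        if tensor_format == "np":
            return np.random.randn(*np.shape(tensor))
        elif tensor_format == "pt":
            return torch.randn_like(tensor)
        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

    def norm(self, tensor):
        tensor_format = getattr(self, "tensor_format", "pt")
        if tensor_format == "np":
            return np.linalg.norm(tensor)
        elif tensor_format == "pt":
            return torch.norm(tensor.reshape(tensor.shape[0], -1), dim=-1).mean()
        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

demo_np = FormatDispatchDemo("np")
demo_pt = FormatDispatchDemo("pt")
print(demo_np.norm(demo_np.randn_like(np.zeros((2, 4)))))
print(demo_pt.norm(demo_pt.randn_like(torch.zeros(2, 4))))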
@@ -1087,11 +1087,16 @@ class PipelineTesterMixin(unittest.TestCase):
         image = sde_ve(num_inference_steps=2)

         if model.device.type == "cpu":
-            expected_image_sum = 3384805632.0
-            expected_image_mean = 1076.000732421875
+            # patrick's cpu
+            expected_image_sum = 3384805888.0
+            expected_image_mean = 1076.00085
+
+            # m1 mbp
+            # expected_image_sum = 3384805376.0
+            # expected_image_mean = 1076.000610351562
         else:
             expected_image_sum = 3382849024.0
-            expected_image_mean = 1075.3787841796875
+            expected_image_mean = 1075.3788

         assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
         assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
@@ -1109,6 +1114,10 @@ class PipelineTesterMixin(unittest.TestCase):
         expected_image_sum = 4183.2012
         expected_image_mean = 1.3617

+        # on m1 mbp
+        # expected_image_sum = 4318.6729
+        # expected_image_mean = 1.4058
+
         assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
         assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
......
@@ -12,15 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import pdb

 import tempfile
 import unittest

 import numpy as np
 import torch

-from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler
+from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler, ScoreSdeVeScheduler

 torch.backends.cuda.matmul.allow_tf32 = False
@@ -208,7 +207,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
     def get_scheduler_config(self, **kwargs):
         config = {
-            "timesteps": 1000,
+            "num_train_timesteps": 1000,
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
@@ -221,7 +220,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
     def test_timesteps(self):
         for timesteps in [1, 5, 100, 1000]:
-            self.check_over_configs(timesteps=timesteps)
+            self.check_over_configs(num_train_timesteps=timesteps)

     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
@@ -288,7 +287,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
     def get_scheduler_config(self, **kwargs):
         config = {
-            "timesteps": 1000,
+            "num_train_timesteps": 1000,
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
@@ -300,7 +299,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
     def test_timesteps(self):
         for timesteps in [100, 500, 1000]:
-            self.check_over_configs(timesteps=timesteps)
+            self.check_over_configs(num_train_timesteps=timesteps)

     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
@@ -367,7 +366,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
     def get_scheduler_config(self, **kwargs):
         config = {
-            "timesteps": 1000,
+            "num_train_timesteps": 1000,
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
@@ -431,11 +430,11 @@ class PNDMSchedulerTest(SchedulerCommonTest):
     def test_timesteps(self):
         for timesteps in [100, 1000]:
-            self.check_over_configs(timesteps=timesteps)
+            self.check_over_configs(num_train_timesteps=timesteps)

     def test_timesteps_pmls(self):
         for timesteps in [100, 1000]:
-            self.check_over_configs_pmls(timesteps=timesteps)
+            self.check_over_configs_pmls(num_train_timesteps=timesteps)

     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
@@ -507,3 +506,115 @@ class PNDMSchedulerTest(SchedulerCommonTest):

         assert abs(result_sum.item() - 199.1169) < 1e-2
         assert abs(result_mean.item() - 0.2593) < 1e-3
+
+
+class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
+    scheduler_classes = (ScoreSdeVeScheduler,)
+
+    def get_scheduler_config(self, **kwargs):
+        config = {
+            "num_train_timesteps": 2000,
+            "snr": 0.15,
+            "sigma_min": 0.01,
+            "sigma_max": 1348,
+            "sampling_eps": 1e-5,
+            "tensor_format": "np",  # TODO add test for tensor formats
+        }
+
+        config.update(**kwargs)
+        return config
+
+    def check_over_configs(self, time_step=0, **config):
+        kwargs = dict(self.forward_default_kwargs)
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_class = self.scheduler_classes[0]
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler = scheduler_class.from_config(tmpdirname)
+
+            output = scheduler.step_pred(residual, sample, time_step, **kwargs)
+            new_output = new_scheduler.step_pred(residual, sample, time_step, **kwargs)
+
+            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+            output = scheduler.step_correct(residual, sample, **kwargs)
+            new_output = new_scheduler.step_correct(residual, sample, **kwargs)
+
+            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"
+
+    def check_over_forward(self, time_step=0, **forward_kwargs):
+        kwargs = dict(self.forward_default_kwargs)
+        kwargs.update(forward_kwargs)
+
+        for scheduler_class in self.scheduler_classes:
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+
+            scheduler_class = self.scheduler_classes[0]
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler = scheduler_class.from_config(tmpdirname)
+
+            output = scheduler.step_pred(residual, sample, time_step, **kwargs)
+            new_output = new_scheduler.step_pred(residual, sample, time_step, **kwargs)
+
+            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+            output = scheduler.step_correct(residual, sample, **kwargs)
+            new_output = new_scheduler.step_correct(residual, sample, **kwargs)
+
+            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"
+
+    def test_timesteps(self):
+        for timesteps in [10, 100, 1000]:
+            self.check_over_configs(num_train_timesteps=timesteps)
+
+    def test_sigmas(self):
+        for sigma_min, sigma_max in zip([0.0001, 0.001, 0.01], [1, 100, 1000]):
+            self.check_over_configs(sigma_min=sigma_min, sigma_max=sigma_max)
+
+    def test_time_indices(self):
+        for t in [1, 5, 10]:
+            self.check_over_forward(time_step=t)
+
+    def test_full_loop_no_noise(self):
+        np.random.seed(0)
+
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        num_inference_steps = 3
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter
+
+        scheduler.set_sigmas(num_inference_steps)
+
+        for i, t in enumerate(scheduler.timesteps):
+            sigma_t = scheduler.sigmas[i]
+
+            for _ in range(scheduler.correct_steps):
+                with torch.no_grad():
+                    result = model(sample, sigma_t)
+                sample = scheduler.step_correct(result, sample)
+
+            with torch.no_grad():
+                result = model(sample, sigma_t)
+
+            sample, sample_mean = scheduler.step_pred(result, sample, t)
+
+        result_sum = np.sum(np.abs(sample))
+        result_mean = np.mean(np.abs(sample))
+
+        assert abs(result_sum.item() - 10629923278.7104) < 1e-2
+        assert abs(result_mean.item() - 13841045.9358) < 1e-3