Unverified Commit 63c68d97 authored by Nathan Lambert, committed by GitHub

VE/VP SDE updates (#90)



* improve comments for sde_ve scheduler, init tests

* more comments, tweaking pipelines

* timesteps --> num_train_timesteps, some comments

* merge cpu test, add m1 data

* fix scheduler tests with num_train_timesteps

* make np compatible, add tests for sde ve

* minor default variable fixes

* make style and fix-copies
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent ba3c9a9a
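The central API change in this commit is the rename of the scheduler constructor argument `timesteps` to `num_train_timesteps`. A minimal sketch, using only the defaults visible in the diff below, of how a caller would construct a scheduler after the rename (values are illustrative, not part of this commit):

    from diffusers import DDPMScheduler

    # `num_train_timesteps` replaces the old `timesteps` argument
    scheduler = DDPMScheduler(
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
    )
    print(len(scheduler))  # __len__ now returns config.num_train_timesteps -> 1000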
......@@ -18,9 +18,6 @@ class ScoreSdeVePipeline(DiffusionPipeline):
model = self.model.to(device)
# TODO(Patrick) move to scheduler config
n_steps = 1
x = torch.randn(*shape) * self.scheduler.config.sigma_max
x = x.to(device)
......@@ -30,7 +27,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
for i, t in enumerate(self.scheduler.timesteps):
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)
for _ in range(n_steps):
for _ in range(self.scheduler.correct_steps):
with torch.no_grad():
result = self.model(x, sigma_t)
......
......@@ -27,6 +27,7 @@ class ScoreSdeVpPipeline(DiffusionPipeline):
t = t * torch.ones(shape[0], device=device)
scaled_t = t * (num_inference_steps - 1)
# TODO add corrector
with torch.no_grad():
result = model(x, scaled_t)
......
......@@ -51,7 +51,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class DDIMScheduler(SchedulerMixin, ConfigMixin):
def __init__(
self,
timesteps=1000,
num_train_timesteps=1000,
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear",
......@@ -62,7 +62,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
):
super().__init__()
self.register_to_config(
timesteps=timesteps,
num_train_timesteps=num_train_timesteps,
beta_start=beta_start,
beta_end=beta_end,
beta_schedule=beta_schedule,
......@@ -72,13 +72,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
)
if beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=np.float32) ** 2
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(timesteps)
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
......@@ -88,10 +88,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# setable values
self.num_inference_steps = None
self.timesteps = np.arange(0, self.config.timesteps)[::-1].copy()
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
def _get_variance(self, timestep, prev_timestep):
alpha_prod_t = self.alphas_cumprod[timestep]
......@@ -131,7 +128,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# - pred_prev_sample -> "x_t-1"
# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.timesteps // self.num_inference_steps
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
......@@ -183,4 +180,4 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return noisy_samples
def __len__(self):
return self.config.timesteps
return self.config.num_train_timesteps
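The `step` hunk above swaps `config.timesteps` for `config.num_train_timesteps` when locating the previous timestep. A small worked sketch of that arithmetic, with illustrative values:

    num_train_timesteps = 1000
    num_inference_steps = 50
    timestep = 980

    # stride between consecutive inference timesteps
    stride = num_train_timesteps // num_inference_steps  # 20
    prev_timestep = timestep - stride                    # 960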
......@@ -50,7 +50,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class DDPMScheduler(SchedulerMixin, ConfigMixin):
def __init__(
self,
timesteps=1000,
num_train_timesteps=1000,
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear",
......@@ -62,7 +62,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
):
super().__init__()
self.register_to_config(
timesteps=timesteps,
num_train_timesteps=num_train_timesteps,
beta_start=beta_start,
beta_end=beta_end,
beta_schedule=beta_schedule,
......@@ -75,10 +75,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
elif beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(timesteps)
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
......@@ -160,4 +160,4 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return noisy_samples
def __len__(self):
return self.config.timesteps
return self.config.num_train_timesteps
......@@ -50,7 +50,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
class PNDMScheduler(SchedulerMixin, ConfigMixin):
def __init__(
self,
timesteps=1000,
num_train_timesteps=1000,
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear",
......@@ -58,17 +58,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
):
super().__init__()
self.register_to_config(
timesteps=timesteps,
num_train_timesteps=num_train_timesteps,
beta_start=beta_start,
beta_end=beta_end,
beta_schedule=beta_schedule,
)
if beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(timesteps)
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
......@@ -96,10 +96,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
if num_inference_steps in self.prk_time_steps:
return self.prk_time_steps[num_inference_steps]
inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
inference_step_times = list(
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
)
prk_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
np.array([0, self.config.timesteps // num_inference_steps // 2]), self.pndm_order
np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
)
self.prk_time_steps[num_inference_steps] = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
......@@ -109,7 +111,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
if num_inference_steps in self.time_steps:
return self.time_steps[num_inference_steps]
inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
inference_step_times = list(
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
)
self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
return self.time_steps[num_inference_steps]
......@@ -135,6 +139,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
sample: Union[torch.FloatTensor, np.ndarray],
num_inference_steps,
):
"""
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
solution to the differential equation.
"""
t = timestep
prk_time_steps = self.get_prk_time_steps(num_inference_steps)
......@@ -165,6 +173,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
sample: Union[torch.FloatTensor, np.ndarray],
num_inference_steps,
):
"""
Step function propagating the sample with the linear multi-step method. This requires only one new forward pass,
reusing model outputs stored from earlier timesteps to approximate the solution.
"""
t = timestep
if len(self.ets) < 3:
raise ValueError(
......@@ -221,4 +233,4 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return prev_sample
def __len__(self):
return self.config.timesteps
return self.config.num_train_timesteps
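Both `get_prk_time_steps` and `get_time_steps` above now derive the inference grid from `config.num_train_timesteps`. A short sketch of the resulting timestep grid, with illustrative values:

    num_train_timesteps = 1000
    num_inference_steps = 50

    # evenly spaced training timesteps reused at inference time
    inference_step_times = list(
        range(0, num_train_timesteps, num_train_timesteps // num_inference_steps)
    )
    # [0, 20, 40, ..., 980]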
......@@ -15,6 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit
import pdb
import numpy as np
import torch
......@@ -24,61 +25,132 @@ from .scheduling_utils import SchedulerMixin
class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, sampling_eps=1e-5, tensor_format="np"):
"""
The variance exploding stochastic differential equation (SDE) scheduler.
:param snr: coefficient weighting the step from the score sample (from the network) to the random noise. :param
sigma_min: initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
distribution of the data.
:param sigma_max: :param sampling_eps: the end value of sampling, where timesteps decrease progessively from 1 to
epsilon. :param correct_steps: number of correction steps performed on a produced sample. :param tensor_format:
"np" or "pt" for the expected format of samples passed to the Scheduler.
"""
def __init__(
self,
num_train_timesteps=2000,
snr=0.15,
sigma_min=0.01,
sigma_max=1348,
sampling_eps=1e-5,
correct_steps=1,
tensor_format="pt",
):
super().__init__()
self.register_to_config(
num_train_timesteps=num_train_timesteps,
snr=snr,
sigma_min=sigma_min,
sigma_max=sigma_max,
sampling_eps=sampling_eps,
correct_steps=correct_steps,
)
self.sigmas = None
self.discrete_sigmas = None
self.timesteps = None
# TODO - update step to be torch-independent
self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps):
self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
self.timesteps = np.linspace(1, self.config.sampling_eps, num_inference_steps)
elif tensor_format == "pt":
self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
else:
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def set_sigmas(self, num_inference_steps):
if self.timesteps is None:
self.set_timesteps(num_inference_steps)
self.discrete_sigmas = torch.exp(
torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
)
self.sigmas = torch.tensor(
[self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
)
def step_pred(self, result, x, t):
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
self.discrete_sigmas = np.exp(
np.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
)
self.sigmas = np.array(
[self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
)
elif tensor_format == "pt":
self.discrete_sigmas = torch.exp(
torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
)
self.sigmas = torch.tensor(
[self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
)
else:
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def get_adjacent_sigma(self, timesteps, t):
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1])
elif tensor_format == "pt":
return torch.where(
timesteps == 0, torch.zeros_like(t), self.discrete_sigmas[timesteps - 1].to(timesteps.device)
)
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def step_pred(self, score, x, t):
"""
Predict the sample at the previous timestep by reversing the SDE.
"""
# TODO(Patrick) better comments + non-PyTorch
t = t * torch.ones(x.shape[0], device=x.device)
timestep = (t * (len(self.timesteps) - 1)).long()
sigma = self.discrete_sigmas.to(t.device)[timestep]
adjacent_sigma = torch.where(
timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(timestep.device)
)
f = torch.zeros_like(x)
G = torch.sqrt(sigma**2 - adjacent_sigma**2)
t = self.repeat_scalar(t, x.shape[0])
timesteps = self.long((t * (len(self.timesteps) - 1)))
sigma = self.discrete_sigmas[timesteps]
adjacent_sigma = self.get_adjacent_sigma(timesteps, t)
drift = self.zeros_like(x)
diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
# equation 6 in the paper: the score modeled by the network is grad_x log pt(x)
# also equation 47 shows the analog from SDE models to ancestral sampling methods
drift = drift - diffusion[:, None, None, None] ** 2 * score
# equation 6: sample noise for the diffusion term of the SDE
noise = self.randn_like(x)
x_mean = x - drift # subtract because `dt` is a small negative timestep
# TODO is the variable diffusion the correct scaling term for the noise?
x = x_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g
return x, x_mean
f = f - G[:, None, None, None] ** 2 * result
def step_correct(self, score, x):
"""
Correct the predicted sample based on the output score of the network. This is often run repeatedly after
making the prediction for the previous timestep.
"""
# TODO(Patrick) non-PyTorch
z = torch.randn_like(x)
x_mean = x - f
x = x_mean + G[:, None, None, None] * z
return x, x_mean
# For small batch sizes, the paper suggests "replacing norm(z) with sqrt(d), where d is the dim. of z"
# sample noise for correction
noise = self.randn_like(x)
def step_correct(self, result, x):
# TODO(Patrick) better comments + non-PyTorch
noise = torch.randn_like(x)
grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean()
noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
# compute step size from the score, the noise, and the snr
grad_norm = self.norm(score)
noise_norm = self.norm(noise)
step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
step_size = step_size * torch.ones(x.shape[0], device=x.device)
x_mean = x + step_size[:, None, None, None] * result
step_size = self.repeat_scalar(step_size, x.shape[0]) # * self.ones(x.shape[0], device=x.device)
x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise
# compute corrected sample: score term and noise term
x_mean = x + step_size[:, None, None, None] * score
x = x_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
return x
def __len__(self):
return self.config.num_train_timesteps
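For context, a minimal sketch of the predictor-corrector loop that the `ScoreSdeVePipeline` hunk at the top of this diff implements, assuming a placeholder score network `model(x, sigma_t)` (not part of this commit):

    import torch
    from diffusers import ScoreSdeVeScheduler

    scheduler = ScoreSdeVeScheduler(num_train_timesteps=2000, tensor_format="pt")
    scheduler.set_sigmas(num_inference_steps=3)  # also populates scheduler.timesteps

    x = torch.randn(1, 3, 32, 32) * scheduler.config.sigma_max
    for i, t in enumerate(scheduler.timesteps):
        sigma_t = scheduler.sigmas[i] * torch.ones(x.shape[0])
        # corrector: `correct_steps` Langevin-style corrections per timestep
        for _ in range(scheduler.correct_steps):
            with torch.no_grad():
                score = model(x, sigma_t)
            x = scheduler.step_correct(score, x)
        # predictor: one reverse-SDE step
        with torch.no_grad():
            score = model(x, sigma_t)
        x, x_mean = scheduler.step_pred(score, x, t)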
......@@ -24,9 +24,10 @@ from .scheduling_utils import SchedulerMixin
class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
def __init__(self, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
super().__init__()
self.register_to_config(
num_train_timesteps=num_train_timesteps,
beta_min=beta_min,
beta_max=beta_max,
sampling_eps=sampling_eps,
......@@ -39,14 +40,14 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(self, num_inference_steps):
self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
def step_pred(self, result, x, t):
def step_pred(self, score, x, t):
# TODO(Patrick) better comments + non-PyTorch
# postprocess model result
# postprocess model score
log_mean_coeff = (
-0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
)
std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
result = -result / std[:, None, None, None]
score = -score / std[:, None, None, None]
# compute
dt = -1.0 / len(self.timesteps)
......@@ -54,11 +55,14 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min)
drift = -0.5 * beta_t[:, None, None, None] * x
diffusion = torch.sqrt(beta_t)
drift = drift - diffusion[:, None, None, None] ** 2 * result
drift = drift - diffusion[:, None, None, None] ** 2 * score
x_mean = x + drift * dt
# add noise
z = torch.randn_like(x)
x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z
noise = torch.randn_like(x)
x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * noise
return x, x_mean
def __len__(self):
return self.config.num_train_timesteps
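The scaling applied to the model output in `step_pred` above is the marginal standard deviation of the VP SDE at time t. A small numeric sketch of that computation (values are illustrative):

    import numpy as np

    beta_min, beta_max = 0.1, 20.0
    t = 0.5
    log_mean_coeff = -0.25 * t**2 * (beta_max - beta_min) - 0.5 * t * beta_min
    std = np.sqrt(1.0 - np.exp(2.0 * log_mean_coeff))
    # the raw model output is then rescaled before use as a score:
    # score = -result / std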
......@@ -53,12 +53,22 @@ class SchedulerMixin:
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def long(self, tensor):
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
return np.int64(tensor)
elif tensor_format == "pt":
return tensor.long()
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]):
"""
Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
Args:
timesteps: an array or tensor of values to extract.
values: an array or tensor of values to extract.
broadcast_array: an array with a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
Returns:
......@@ -74,3 +84,39 @@ class SchedulerMixin:
values = values.to(broadcast_array.device)
return values
def norm(self, tensor):
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
return np.linalg.norm(tensor)
elif tensor_format == "pt":
return torch.norm(tensor.reshape(tensor.shape[0], -1), dim=-1).mean()
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def randn_like(self, tensor):
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
return np.random.randn(*np.shape(tensor))
elif tensor_format == "pt":
return torch.randn_like(tensor)
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def repeat_scalar(self, tensor, count):
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
return np.repeat(tensor, count)
elif tensor_format == "pt":
return torch.repeat_interleave(tensor, count)
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def zeros_like(self, tensor):
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
return np.zeros_like(tensor)
elif tensor_format == "pt":
return torch.zeros_like(tensor)
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
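The new `SchedulerMixin` helpers above all follow the same dispatch pattern on `self.tensor_format`. A condensed, standalone sketch of that pattern (not the library code itself):

    import numpy as np
    import torch

    def randn_like(tensor, tensor_format="pt"):
        # numpy branch: standard-normal array with the same shape as `tensor`
        if tensor_format == "np":
            return np.random.randn(*np.shape(tensor))
        # torch branch: defer to torch.randn_like
        elif tensor_format == "pt":
            return torch.randn_like(tensor)
        raise ValueError(f"`tensor_format`: {tensor_format} is not valid.")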
......@@ -1087,11 +1087,16 @@ class PipelineTesterMixin(unittest.TestCase):
image = sde_ve(num_inference_steps=2)
if model.device.type == "cpu":
expected_image_sum = 3384805632.0
expected_image_mean = 1076.000732421875
# patrick's cpu
expected_image_sum = 3384805888.0
expected_image_mean = 1076.00085
# m1 mbp
# expected_image_sum = 3384805376.0
# expected_image_mean = 1076.000610351562
else:
expected_image_sum = 3382849024.0
expected_image_mean = 1075.3787841796875
expected_image_mean = 1075.3788
assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
......@@ -1109,6 +1114,10 @@ class PipelineTesterMixin(unittest.TestCase):
expected_image_sum = 4183.2012
expected_image_mean = 1.3617
# on m1 mbp
# expected_image_sum = 4318.6729
# expected_image_mean = 1.4058
assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
......
......@@ -12,15 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pdb
import tempfile
import unittest
import numpy as np
import torch
from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler
from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler, ScoreSdeVeScheduler
torch.backends.cuda.matmul.allow_tf32 = False
......@@ -208,7 +207,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
def get_scheduler_config(self, **kwargs):
config = {
"timesteps": 1000,
"num_train_timesteps": 1000,
"beta_start": 0.0001,
"beta_end": 0.02,
"beta_schedule": "linear",
......@@ -221,7 +220,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
def test_timesteps(self):
for timesteps in [1, 5, 100, 1000]:
self.check_over_configs(timesteps=timesteps)
self.check_over_configs(num_train_timesteps=timesteps)
def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
......@@ -288,7 +287,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
def get_scheduler_config(self, **kwargs):
config = {
"timesteps": 1000,
"num_train_timesteps": 1000,
"beta_start": 0.0001,
"beta_end": 0.02,
"beta_schedule": "linear",
......@@ -300,7 +299,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
def test_timesteps(self):
for timesteps in [100, 500, 1000]:
self.check_over_configs(timesteps=timesteps)
self.check_over_configs(num_train_timesteps=timesteps)
def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
......@@ -367,7 +366,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
def get_scheduler_config(self, **kwargs):
config = {
"timesteps": 1000,
"num_train_timesteps": 1000,
"beta_start": 0.0001,
"beta_end": 0.02,
"beta_schedule": "linear",
......@@ -431,11 +430,11 @@ class PNDMSchedulerTest(SchedulerCommonTest):
def test_timesteps(self):
for timesteps in [100, 1000]:
self.check_over_configs(timesteps=timesteps)
self.check_over_configs(num_train_timesteps=timesteps)
def test_timesteps_pmls(self):
for timesteps in [100, 1000]:
self.check_over_configs_pmls(timesteps=timesteps)
self.check_over_configs_pmls(num_train_timesteps=timesteps)
def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
......@@ -507,3 +506,115 @@ class PNDMSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 199.1169) < 1e-2
assert abs(result_mean.item() - 0.2593) < 1e-3
class ScoreSdeVeSchedulerTest(SchedulerCommonTest):
scheduler_classes = (ScoreSdeVeScheduler,)
def get_scheduler_config(self, **kwargs):
config = {
"num_train_timesteps": 2000,
"snr": 0.15,
"sigma_min": 0.01,
"sigma_max": 1348,
"sampling_eps": 1e-5,
"tensor_format": "np", # TODO add test for tensor formats
}
config.update(**kwargs)
return config
def check_over_configs(self, time_step=0, **config):
kwargs = dict(self.forward_default_kwargs)
for scheduler_class in self.scheduler_classes:
scheduler_class = self.scheduler_classes[0]
sample = self.dummy_sample
residual = 0.1 * sample
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
output = scheduler.step_pred(residual, sample, time_step, **kwargs)
new_output = new_scheduler.step_pred(residual, sample, time_step, **kwargs)
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step_correct(residual, sample, **kwargs)
new_output = new_scheduler.step_correct(residual, sample, **kwargs)
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler corrections are not identical"
def check_over_forward(self, time_step=0, **forward_kwargs):
kwargs = dict(self.forward_default_kwargs)
kwargs.update(forward_kwargs)
for scheduler_class in self.scheduler_classes:
sample = self.dummy_sample
residual = 0.1 * sample
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
output = scheduler.step_pred(residual, sample, time_step, **kwargs)
new_output = new_scheduler.step_pred(residual, sample, time_step, **kwargs)
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step_correct(residual, sample, **kwargs)
new_output = new_scheduler.step_correct(residual, sample, **kwargs)
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler corrections are not identical"
def test_timesteps(self):
for timesteps in [10, 100, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
def test_sigmas(self):
for sigma_min, sigma_max in zip([0.0001, 0.001, 0.01], [1, 100, 1000]):
self.check_over_configs(sigma_min=sigma_min, sigma_max=sigma_max)
def test_time_indices(self):
for t in [1, 5, 10]:
self.check_over_forward(time_step=t)
def test_full_loop_no_noise(self):
np.random.seed(0)
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
num_inference_steps = 3
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_sigmas(num_inference_steps)
for i, t in enumerate(scheduler.timesteps):
sigma_t = scheduler.sigmas[i]
for _ in range(scheduler.correct_steps):
with torch.no_grad():
result = model(sample, sigma_t)
sample = scheduler.step_correct(result, sample)
with torch.no_grad():
result = model(sample, sigma_t)
sample, sample_mean = scheduler.step_pred(result, sample, t)
result_sum = np.sum(np.abs(sample))
result_mean = np.mean(np.abs(sample))
assert abs(result_sum.item() - 10629923278.7104) < 1e-2
assert abs(result_mean.item() - 13841045.9358) < 1e-3