Unverified Commit bd8df2da authored by Kashif Rasul, committed by GitHub

[Pytorch] Pytorch only schedulers (#534)



* pytorch only schedulers

* fix style

* remove match_shape

* pytorch only ddpm

* remove SchedulerMixin

* remove numpy from karras_ve

* fix types

* remove numpy from lms_discrete

* remove numpy from pndm

* fix typo

* remove mixin and numpy from sde_vp and ve

* remove remaining tensor_format

* fix style

* sigmas has to be torch tensor

* removed set_format in readme

* remove set format from docs

* remove set_format from pipelines

* update tests

* fix typo

* continue to use mixin

* fix imports

* removed unused imports

* match shape instead of assuming image shapes

* remove import typo

* update call to add_noise

* use math instead of numpy

* fix t_index

* removed commented out numpy tests

* timesteps needs to be discrete

* cast timesteps to int in flax scheduler too

* fix device mismatch issue

* small fix

* Update src/diffusers/schedulers/scheduling_pndm.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 3b747de8
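For context, a minimal usage sketch of the torch-only scheduler API this change moves to (assuming the PNDMScheduler of this diffusers version; the dummy tensors, step count, and variable names are illustrative, not taken from the diff). Schedulers no longer accept a `tensor_format` argument and work with torch tensors end to end:

import torch
from diffusers import PNDMScheduler

# no tensor_format="pt" argument anymore; torch tensors in, torch tensors out
scheduler = PNDMScheduler(skip_prk_steps=True)
scheduler.set_timesteps(50)

sample = torch.randn(1, 3, 32, 32)        # current noisy sample
model_output = torch.randn(1, 3, 32, 32)  # stand-in for a UNet prediction

t = int(scheduler.timesteps[0])
prev_sample = scheduler.step(model_output, t, sample).prev_sample  # torch.FloatTensor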
@@ -51,7 +51,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
         t1 = i / num_diffusion_timesteps
         t2 = (i + 1) / num_diffusion_timesteps
         betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
-    return np.array(betas, dtype=np.float32)
+    return torch.tensor(betas, dtype=torch.float32)


 class PNDMScheduler(SchedulerMixin, ConfigMixin):
@@ -86,7 +86,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             an offset added to the inference steps. You can use a combination of `offset=1` and
             `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
             stable diffusion.
-        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays

     """

@@ -101,15 +100,16 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         skip_prk_steps: bool = False,
         set_alpha_to_one: bool = False,
         steps_offset: int = 0,
-        tensor_format: str = "pt",
     ):
         if trained_betas is not None:
-            self.betas = np.asarray(trained_betas)
+            self.betas = torch.from_numpy(trained_betas)
         if beta_schedule == "linear":
-            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
             # this schedule is very specific to the latent diffusion model.
-            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+            self.betas = (
+                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+            )
         elif beta_schedule == "squaredcos_cap_v2":
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -117,9 +117,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

         self.alphas = 1.0 - self.betas
-        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

-        self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

         # For now we only support F-PNDM, i.e. the runge-kutta method
         # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
@@ -139,9 +139,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         self.plms_timesteps = None
         self.timesteps = None

-        self.tensor_format = tensor_format
-        self.set_format(tensor_format=tensor_format)
-
     def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -189,13 +186,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         self.ets = []
         self.counter = 0
-        self.set_format(tensor_format=self.tensor_format)

     def step(
         self,
-        model_output: Union[torch.FloatTensor, np.ndarray],
+        model_output: torch.FloatTensor,
         timestep: int,
-        sample: Union[torch.FloatTensor, np.ndarray],
+        sample: torch.FloatTensor,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -205,9 +201,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.

         Args:
-            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
             timestep (`int`): current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor` or `np.ndarray`):
+            sample (`torch.FloatTensor`):
                 current instance of sample being created by diffusion process.
             return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

@@ -224,9 +220,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
     def step_prk(
         self,
-        model_output: Union[torch.FloatTensor, np.ndarray],
+        model_output: torch.FloatTensor,
         timestep: int,
-        sample: Union[torch.FloatTensor, np.ndarray],
+        sample: torch.FloatTensor,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -234,9 +230,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         solution to the differential equation.

         Args:
-            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
             timestep (`int`): current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor` or `np.ndarray`):
+            sample (`torch.FloatTensor`):
                 current instance of sample being created by diffusion process.
             return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

@@ -279,9 +275,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
     def step_plms(
         self,
-        model_output: Union[torch.FloatTensor, np.ndarray],
+        model_output: torch.FloatTensor,
         timestep: int,
-        sample: Union[torch.FloatTensor, np.ndarray],
+        sample: torch.FloatTensor,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -289,9 +285,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         times to approximate the solution.

         Args:
-            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
             timestep (`int`): current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor` or `np.ndarray`):
+            sample (`torch.FloatTensor`):
                 current instance of sample being created by diffusion process.
             return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

@@ -381,16 +377,27 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
     def add_noise(
         self,
-        original_samples: Union[torch.FloatTensor, np.ndarray],
-        noise: Union[torch.FloatTensor, np.ndarray],
-        timesteps: Union[torch.IntTensor, np.ndarray],
+        original_samples: torch.FloatTensor,
+        noise: torch.FloatTensor,
+        timesteps: torch.IntTensor,
     ) -> torch.Tensor:
-        if self.tensor_format == "pt":
-            timesteps = timesteps.to(self.alphas_cumprod.device)
+        if self.alphas_cumprod.device != original_samples.device:
+            self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
+        if timesteps.device != original_samples.device:
+            timesteps = timesteps.to(original_samples.device)
+        timesteps = timesteps.to(self.alphas_cumprod.device)

         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
-        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
...
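The `flatten()` / `unsqueeze(-1)` pattern above replaces the removed `match_shape` helper. A standalone sketch of the same broadcasting, with illustrative shapes and numbers (not taken from the repository):

import torch

betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

original_samples = torch.randn(3, 3, 32, 32)          # batch of 3 images
noise = torch.randn_like(original_samples)
timesteps = torch.tensor([10, 500, 900])               # one timestep per batch element

sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5      # shape (3,)
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
    sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)     # -> shape (3, 1, 1, 1), broadcastable

sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise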
@@ -14,11 +14,11 @@

 # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch

+import math
 import warnings
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union

-import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
@@ -65,7 +65,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
             epsilon.
         correct_steps (`int`): number of correction steps performed on a produced sample.
-        tensor_format (`str`): "np" or "pt" for the expected format of samples passed to the Scheduler.
     """

     @register_to_config
@@ -77,16 +76,12 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         sigma_max: float = 1348.0,
         sampling_eps: float = 1e-5,
         correct_steps: int = 1,
-        tensor_format: str = "pt",
     ):
         # setable values
         self.timesteps = None

         self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)

-        self.tensor_format = tensor_format
-        self.set_format(tensor_format=tensor_format)
-
     def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None):
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -98,13 +93,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         """
         sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps

-        tensor_format = getattr(self, "tensor_format", "pt")
-        if tensor_format == "np":
-            self.timesteps = np.linspace(1, sampling_eps, num_inference_steps)
-        elif tensor_format == "pt":
-            self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps)
-        else:
-            raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+        self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps)

     def set_sigmas(
         self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
@@ -129,28 +119,16 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         if self.timesteps is None:
             self.set_timesteps(num_inference_steps, sampling_eps)

-        tensor_format = getattr(self, "tensor_format", "pt")
-        if tensor_format == "np":
-            self.discrete_sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
-            self.sigmas = np.array([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
-        elif tensor_format == "pt":
-            self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
-            self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
-        else:
-            raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+        self.sigmas = sigma_min * (sigma_max / sigma_min) ** (self.timesteps / sampling_eps)
+        self.discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps))
+        self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])

     def get_adjacent_sigma(self, timesteps, t):
-        tensor_format = getattr(self, "tensor_format", "pt")
-        if tensor_format == "np":
-            return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1])
-        elif tensor_format == "pt":
-            return torch.where(
-                timesteps == 0,
-                torch.zeros_like(t.to(timesteps.device)),
-                self.discrete_sigmas[timesteps - 1].to(timesteps.device),
-            )
-        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+        return torch.where(
+            timesteps == 0,
+            torch.zeros_like(t.to(timesteps.device)),
+            self.discrete_sigmas[timesteps - 1].to(timesteps.device),
+        )

     def set_seed(self, seed):
         warnings.warn(
@@ -158,19 +136,13 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
             " generator instead.",
             DeprecationWarning,
         )
-        tensor_format = getattr(self, "tensor_format", "pt")
-        if tensor_format == "np":
-            np.random.seed(seed)
-        elif tensor_format == "pt":
-            torch.manual_seed(seed)
-        else:
-            raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+        torch.manual_seed(seed)

     def step_pred(
         self,
-        model_output: Union[torch.FloatTensor, np.ndarray],
+        model_output: torch.FloatTensor,
         timestep: int,
-        sample: Union[torch.FloatTensor, np.ndarray],
+        sample: torch.FloatTensor,
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
         **kwargs,
@@ -180,9 +152,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).

         Args:
-            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
             timestep (`int`): current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor` or `np.ndarray`):
+            sample (`torch.FloatTensor`):
                 current instance of sample being created by diffusion process.
             generator: random number generator.
             return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

@@ -210,18 +182,21 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         sigma = self.discrete_sigmas[timesteps].to(sample.device)
         adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
-        drift = self.zeros_like(sample)
+        drift = torch.zeros_like(sample)
         diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5

         # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
         # also equation 47 shows the analog from SDE models to ancestral sampling methods
-        drift = drift - diffusion[:, None, None, None] ** 2 * model_output
+        diffusion = diffusion.flatten()
+        while len(diffusion.shape) < len(sample.shape):
+            diffusion = diffusion.unsqueeze(-1)
+        drift = drift - diffusion**2 * model_output

         # equation 6: sample noise for the diffusion term of
-        noise = self.randn_like(sample, generator=generator)
+        noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device)
         prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
         # TODO is the variable diffusion the correct scaling term for the noise?
-        prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise  # add impact of diffusion field g
+        prev_sample = prev_sample_mean + diffusion * noise  # add impact of diffusion field g

         if not return_dict:
             return (prev_sample, prev_sample_mean)
@@ -230,8 +205,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
     def step_correct(
         self,
-        model_output: Union[torch.FloatTensor, np.ndarray],
-        sample: Union[torch.FloatTensor, np.ndarray],
+        model_output: torch.FloatTensor,
+        sample: torch.FloatTensor,
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
         **kwargs,
@@ -241,8 +216,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         after making the prediction for the previous timestep.

         Args:
-            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
-            sample (`torch.FloatTensor` or `np.ndarray`):
+            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+            sample (`torch.FloatTensor`):
                 current instance of sample being created by diffusion process.
             generator: random number generator.
             return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

@@ -262,18 +237,21 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
         # sample noise for correction
-        noise = self.randn_like(sample, generator=generator)
+        noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device)

         # compute step size from the model_output, the noise, and the snr
-        grad_norm = self.norm(model_output)
-        noise_norm = self.norm(noise)
+        grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean()
+        noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
         step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
         step_size = step_size * torch.ones(sample.shape[0]).to(sample.device)
         # self.repeat_scalar(step_size, sample.shape[0])

         # compute corrected sample: model_output term and noise term
-        prev_sample_mean = sample + step_size[:, None, None, None] * model_output
-        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
+        step_size = step_size.flatten()
+        while len(step_size.shape) < len(sample.shape):
+            step_size = step_size.unsqueeze(-1)
+        prev_sample_mean = sample + step_size * model_output
+        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise

         if not return_dict:
             return (prev_sample,)
...
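A small sketch of the torch-only idioms used above in place of the removed `self.randn_like` and `self.norm` helpers (the tensors and the seed are illustrative, not from the repository):

import torch

sample = torch.randn(2, 3, 64, 64)
model_output = torch.randn_like(sample)
generator = torch.Generator().manual_seed(0)

# reproducible noise: drawn with the caller's generator, then moved to the sample's device
noise = torch.randn(sample.shape, layout=sample.layout, generator=generator).to(sample.device)

# per-example L2 norms over the flattened dims, averaged over the batch (used for the corrector step size)
grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean()
noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()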
@@ -16,7 +16,8 @@

 # TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit

-import numpy as np
+import math
+
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
@@ -39,7 +40,7 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
     """

     @register_to_config
-    def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
+    def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3):
         self.sigmas = None
         self.discrete_sigmas = None
         self.timesteps = None
@@ -47,7 +48,7 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
     def set_timesteps(self, num_inference_steps):
         self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)

-    def step_pred(self, score, x, t):
+    def step_pred(self, score, x, t, generator=None):
         if self.timesteps is None:
             raise ValueError(
                 "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
@@ -59,20 +60,27 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
             -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
         )
         std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
-        score = -score / std[:, None, None, None]
+        std = std.flatten()
+        while len(std.shape) < len(score.shape):
+            std = std.unsqueeze(-1)
+        score = -score / std

         # compute
         dt = -1.0 / len(self.timesteps)

         beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min)
-        drift = -0.5 * beta_t[:, None, None, None] * x
+        beta_t = beta_t.flatten()
+        while len(beta_t.shape) < len(x.shape):
+            beta_t = beta_t.unsqueeze(-1)
+        drift = -0.5 * beta_t * x
         diffusion = torch.sqrt(beta_t)
-        drift = drift - diffusion[:, None, None, None] ** 2 * score
+        drift = drift - diffusion**2 * score
         x_mean = x + drift * dt

         # add noise
-        noise = torch.randn_like(x)
-        x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * noise
+        noise = torch.randn(x.shape, layout=x.layout, generator=generator).to(x.device)
+        x = x_mean + diffusion * math.sqrt(-dt) * noise

         return x, x_mean
...
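`ScoreSdeVpScheduler.step_pred` now takes an optional `torch.Generator` so the added noise is reproducible. A hedged usage sketch (the shapes, seed, and dummy score are illustrative, not from the diff):

import torch
from diffusers.schedulers.scheduling_sde_vp import ScoreSdeVpScheduler

scheduler = ScoreSdeVpScheduler()
scheduler.set_timesteps(10)

x = torch.randn(1, 3, 32, 32)
score = torch.randn_like(x)                          # stand-in for a score-model output
t = scheduler.timesteps[0] * torch.ones(x.shape[0])  # continuous timestep per batch element
generator = torch.Generator().manual_seed(0)

x, x_mean = scheduler.step_pred(score, x, t, generator=generator)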
@@ -12,9 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Union

-import numpy as np
 import torch

 from ..utils import BaseOutput
@@ -43,83 +41,3 @@ class SchedulerMixin:
     """

     config_name = SCHEDULER_CONFIG_NAME
-    ignore_for_config = ["tensor_format"]
-
-    def set_format(self, tensor_format="pt"):
-        self.tensor_format = tensor_format
-        if tensor_format == "pt":
-            for key, value in vars(self).items():
-                if isinstance(value, np.ndarray):
-                    setattr(self, key, torch.from_numpy(value))
-
-        return self
-
-    def clip(self, tensor, min_value=None, max_value=None):
-        tensor_format = getattr(self, "tensor_format", "pt")
-
-        if tensor_format == "np":
-            return np.clip(tensor, min_value, max_value)
-        elif tensor_format == "pt":
-            return torch.clamp(tensor, min_value, max_value)
-
-        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
-
-    def log(self, tensor):
-        tensor_format = getattr(self, "tensor_format", "pt")
-
-        if tensor_format == "np":
-            return np.log(tensor)
-        elif tensor_format == "pt":
-            return torch.log(tensor)
-
-        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
-
-    def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]):
-        """
-        Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
-
-        Args:
-            values: an array or tensor of values to extract.
-            broadcast_array: an array with a larger shape of K dimensions with the batch
-                dimension equal to the length of timesteps.
-        Returns:
-            a tensor of shape [batch_size, 1, ...] where the shape has K dims.
-        """
-
-        tensor_format = getattr(self, "tensor_format", "pt")
-        values = values.flatten()
-
-        while len(values.shape) < len(broadcast_array.shape):
-            values = values[..., None]
-        if tensor_format == "pt":
-            values = values.to(broadcast_array.device)
-
-        return values
-
-    def norm(self, tensor):
-        tensor_format = getattr(self, "tensor_format", "pt")
-        if tensor_format == "np":
-            return np.linalg.norm(tensor)
-        elif tensor_format == "pt":
-            return torch.norm(tensor.reshape(tensor.shape[0], -1), dim=-1).mean()
-
-        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
-
-    def randn_like(self, tensor, generator=None):
-        tensor_format = getattr(self, "tensor_format", "pt")
-        if tensor_format == "np":
-            return np.random.randn(*np.shape(tensor))
-        elif tensor_format == "pt":
-            # return torch.randn_like(tensor)
-            return torch.randn(tensor.shape, layout=tensor.layout, generator=generator).to(tensor.device)
-
-        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
-
-    def zeros_like(self, tensor):
-        tensor_format = getattr(self, "tensor_format", "pt")
-        if tensor_format == "np":
-            return np.zeros_like(tensor)
-        elif tensor_format == "pt":
-            return torch.zeros_like(tensor)
-
-        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
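The removed mixin helpers map onto plain torch calls, which the schedulers now use inline; `set_format` itself disappears because every scheduler holds torch tensors from construction. A rough one-to-one sketch (the tensor `t` is a stand-in, not from the repository):

import torch

t = torch.rand(4, 3, 8, 8) + 0.1   # stand-in tensor (kept positive so the log is defined)

torch.clamp(t, -1.0, 1.0)                               # was self.clip(t, -1.0, 1.0)
torch.log(t)                                            # was self.log(t)
torch.norm(t.reshape(t.shape[0], -1), dim=-1).mean()    # was self.norm(t)
torch.zeros_like(t)                                     # was self.zeros_like(t)
torch.randn(t.shape, layout=t.layout)                   # was self.randn_like(t)

# was self.match_shape(values, broadcast_array): expand a per-batch vector for broadcasting
values = torch.rand(4)
while len(values.shape) < len(t.shape):
    values = values.unsqueeze(-1)                       # -> shape (4, 1, 1, 1)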
@@ -191,7 +191,7 @@ class PipelineFastTests(unittest.TestCase):
     def test_ddim(self):
         unet = self.dummy_uncond_unet
-        scheduler = DDIMScheduler(tensor_format="pt")
+        scheduler = DDIMScheduler()

         ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
         ddpm.to(torch_device)
@@ -220,7 +220,7 @@ class PipelineFastTests(unittest.TestCase):
     def test_pndm_cifar10(self):
         unet = self.dummy_uncond_unet
-        scheduler = PNDMScheduler(tensor_format="pt")
+        scheduler = PNDMScheduler()

         pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
         pndm.to(torch_device)
@@ -242,7 +242,7 @@ class PipelineFastTests(unittest.TestCase):
     def test_ldm_text2img(self):
         unet = self.dummy_cond_unet
-        scheduler = DDIMScheduler(tensor_format="pt")
+        scheduler = DDIMScheduler()
         vae = self.dummy_vae
         bert = self.dummy_text_encoder
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
@@ -339,7 +339,7 @@ class PipelineFastTests(unittest.TestCase):
     def test_stable_diffusion_pndm(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         unet = self.dummy_cond_unet
-        scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
+        scheduler = PNDMScheduler(skip_prk_steps=True)
         vae = self.dummy_vae
         bert = self.dummy_text_encoder
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
@@ -460,7 +460,7 @@ class PipelineFastTests(unittest.TestCase):
     def test_score_sde_ve_pipeline(self):
         unet = self.dummy_uncond_unet
-        scheduler = ScoreSdeVeScheduler(tensor_format="pt")
+        scheduler = ScoreSdeVeScheduler()

         sde_ve = ScoreSdeVePipeline(unet=unet, scheduler=scheduler)
         sde_ve.to(torch_device)
@@ -484,7 +484,7 @@ class PipelineFastTests(unittest.TestCase):
     def test_ldm_uncond(self):
         unet = self.dummy_uncond_unet
-        scheduler = DDIMScheduler(tensor_format="pt")
+        scheduler = DDIMScheduler()
         vae = self.dummy_vq_model

         ldm = LDMPipeline(unet=unet, vqvae=vae, scheduler=scheduler)
@@ -512,7 +512,7 @@ class PipelineFastTests(unittest.TestCase):
     def test_karras_ve_pipeline(self):
         unet = self.dummy_uncond_unet
-        scheduler = KarrasVeScheduler(tensor_format="pt")
+        scheduler = KarrasVeScheduler()

         pipe = KarrasVePipeline(unet=unet, scheduler=scheduler)
         pipe.to(torch_device)
@@ -535,7 +535,7 @@ class PipelineFastTests(unittest.TestCase):
     def test_stable_diffusion_img2img(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         unet = self.dummy_cond_unet
-        scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
+        scheduler = PNDMScheduler(skip_prk_steps=True)
         vae = self.dummy_vae
         bert = self.dummy_text_encoder
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
@@ -646,7 +646,7 @@ class PipelineFastTests(unittest.TestCase):
     def test_stable_diffusion_inpaint(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         unet = self.dummy_cond_unet
-        scheduler = PNDMScheduler(tensor_format="pt", skip_prk_steps=True)
+        scheduler = PNDMScheduler(skip_prk_steps=True)
         vae = self.dummy_vae
         bert = self.dummy_text_encoder
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
@@ -842,7 +842,6 @@ class PipelineTesterMixin(unittest.TestCase):
         unet = UNet2DModel.from_pretrained(model_id)
         scheduler = DDPMScheduler.from_config(model_id)
-        scheduler = scheduler.set_format("pt")

         ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
         ddpm.to(torch_device)
@@ -882,7 +881,7 @@ class PipelineTesterMixin(unittest.TestCase):
         model_id = "google/ddpm-cifar10-32"

         unet = UNet2DModel.from_pretrained(model_id)
-        scheduler = DDIMScheduler(tensor_format="pt")
+        scheduler = DDIMScheduler()

         ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
         ddim.to(torch_device)
@@ -902,7 +901,7 @@ class PipelineTesterMixin(unittest.TestCase):
         model_id = "google/ddpm-cifar10-32"

         unet = UNet2DModel.from_pretrained(model_id)
-        scheduler = PNDMScheduler(tensor_format="pt")
+        scheduler = PNDMScheduler()

         pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
         pndm.to(torch_device)
@@ -1043,8 +1042,8 @@ class PipelineTesterMixin(unittest.TestCase):
         model_id = "google/ddpm-cifar10-32"

         unet = UNet2DModel.from_pretrained(model_id)
-        ddpm_scheduler = DDPMScheduler(tensor_format="pt")
-        ddim_scheduler = DDIMScheduler(tensor_format="pt")
+        ddpm_scheduler = DDPMScheduler()
+        ddim_scheduler = DDIMScheduler()

         ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
         ddpm.to(torch_device)
@@ -1067,8 +1066,8 @@ class PipelineTesterMixin(unittest.TestCase):
         model_id = "google/ddpm-cifar10-32"

         unet = UNet2DModel.from_pretrained(model_id)
-        ddpm_scheduler = DDPMScheduler(tensor_format="pt")
-        ddim_scheduler = DDIMScheduler(tensor_format="pt")
+        ddpm_scheduler = DDPMScheduler()
+        ddim_scheduler = DDIMScheduler()

         ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
         ddpm.to(torch_device)
@@ -1093,7 +1092,7 @@ class PipelineTesterMixin(unittest.TestCase):
     def test_karras_ve_pipeline(self):
         model_id = "google/ncsnpp-celebahq-256"
         model = UNet2DModel.from_pretrained(model_id)
-        scheduler = KarrasVeScheduler(tensor_format="pt")
+        scheduler = KarrasVeScheduler()

         pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
         pipe.to(torch_device)
...
@@ -173,34 +173,6 @@ class SchedulerCommonTest(unittest.TestCase):
         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)

-    def test_pytorch_equal_numpy(self):
-        kwargs = dict(self.forward_default_kwargs)
-        num_inference_steps = kwargs.pop("num_inference_steps", None)
-
-        for scheduler_class in self.scheduler_classes:
-            sample_pt = self.dummy_sample
-            residual_pt = 0.1 * sample_pt
-
-            sample = sample_pt.numpy()
-            residual = 0.1 * sample
-
-            scheduler_config = self.get_scheduler_config()
-            scheduler = scheduler_class(tensor_format="np", **scheduler_config)
-            scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
-
-            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
-                scheduler.set_timesteps(num_inference_steps)
-                scheduler_pt.set_timesteps(num_inference_steps)
-            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
-                kwargs["num_inference_steps"] = num_inference_steps
-
-            output = scheduler.step(residual, 1, sample, **kwargs).prev_sample
-            output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs).prev_sample
-
-            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
-
     def test_scheduler_outputs_equivalence(self):
         def set_nan_tensor_to_zero(t):
             t[t != t] = 0
@@ -266,7 +238,6 @@ class DDPMSchedulerTest(SchedulerCommonTest):
             "beta_schedule": "linear",
             "variance_type": "fixed_small",
             "clip_sample": True,
-            "tensor_format": "pt",
         }

         config.update(**kwargs)
@@ -305,10 +276,6 @@ class DDPMSchedulerTest(SchedulerCommonTest):
         assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5
         assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.02)) < 1e-5

-    # TODO Make DDPM Numpy compatible
-    def test_pytorch_equal_numpy(self):
-        pass
-
     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
@@ -387,7 +354,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
         scheduler_config = self.get_scheduler_config(steps_offset=1)
         scheduler = scheduler_class(**scheduler_config)
         scheduler.set_timesteps(5)
-        assert torch.equal(scheduler.timesteps, torch.tensor([801, 601, 401, 201, 1]))
+        assert np.equal(scheduler.timesteps, np.array([801, 601, 401, 201, 1])).all()

     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
@@ -556,72 +523,6 @@ class PNDMSchedulerTest(SchedulerCommonTest):

         return sample

-    def test_pytorch_equal_numpy(self):
-        kwargs = dict(self.forward_default_kwargs)
-        num_inference_steps = kwargs.pop("num_inference_steps", None)
-
-        for scheduler_class in self.scheduler_classes:
-            sample_pt = self.dummy_sample
-            residual_pt = 0.1 * sample_pt
-            dummy_past_residuals_pt = [residual_pt + 0.2, residual_pt + 0.15, residual_pt + 0.1, residual_pt + 0.05]
-
-            sample = sample_pt.numpy()
-            residual = 0.1 * sample
-            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
-
-            scheduler_config = self.get_scheduler_config()
-            scheduler = scheduler_class(tensor_format="np", **scheduler_config)
-            scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
-
-            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
-                scheduler.set_timesteps(num_inference_steps)
-                scheduler_pt.set_timesteps(num_inference_steps)
-            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
-                kwargs["num_inference_steps"] = num_inference_steps
-
-            # copy over dummy past residuals (must be done after set_timesteps)
-            scheduler.ets = dummy_past_residuals[:]
-            scheduler_pt.ets = dummy_past_residuals_pt[:]
-
-            output = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample
-            output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs).prev_sample
-            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
-
-            output = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample
-            output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs).prev_sample
-            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
-
-    def test_set_format(self):
-        kwargs = dict(self.forward_default_kwargs)
-        num_inference_steps = kwargs.pop("num_inference_steps", None)
-
-        for scheduler_class in self.scheduler_classes:
-            scheduler_config = self.get_scheduler_config()
-            scheduler = scheduler_class(tensor_format="np", **scheduler_config)
-            scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
-
-            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
-                scheduler.set_timesteps(num_inference_steps)
-                scheduler_pt.set_timesteps(num_inference_steps)
-
-            for key, value in vars(scheduler).items():
-                # we only allow `ets` attr to be a list
-                assert not isinstance(value, list) or key in [
-                    "ets"
-                ], f"Scheduler is not correctly set to np format, the attribute {key} is {type(value)}"
-
-            # check if `scheduler.set_format` does convert correctly attrs to pt format
-            for key, value in vars(scheduler_pt).items():
-                # we only allow `ets` attr to be a list
-                assert not isinstance(value, list) or key in [
-                    "ets"
-                ], f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}"
-                assert not isinstance(
-                    value, np.ndarray
-                ), f"Scheduler is not correctly set to pt format, the attribute {key} is {type(value)}"
-
     def test_step_shape(self):
         kwargs = dict(self.forward_default_kwargs)
@@ -667,12 +568,10 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         scheduler_config = self.get_scheduler_config(steps_offset=1)
         scheduler = scheduler_class(**scheduler_config)
         scheduler.set_timesteps(10)
-        assert torch.equal(
+        assert np.equal(
             scheduler.timesteps,
-            torch.tensor(
-                [901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]
-            ),
-        )
+            np.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]),
+        ).all()

     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
@@ -786,7 +685,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
             "sigma_min": 0.01,
             "sigma_max": 1348,
             "sampling_eps": 1e-5,
-            "tensor_format": "pt",  # TODO add test for tensor formats
         }

         config.update(**kwargs)
@@ -936,7 +834,6 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
             "beta_end": 0.02,
             "beta_schedule": "linear",
             "trained_betas": None,
-            "tensor_format": "pt",
         }

         config.update(**kwargs)
@@ -958,28 +855,6 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
         for t in [0, 500, 800]:
             self.check_over_forward(time_step=t)

-    def test_pytorch_equal_numpy(self):
-        for scheduler_class in self.scheduler_classes:
-            sample_pt = self.dummy_sample
-            residual_pt = 0.1 * sample_pt
-
-            sample = sample_pt.numpy()
-            residual = 0.1 * sample
-
-            scheduler_config = self.get_scheduler_config()
-            scheduler_config["tensor_format"] = "np"
-            scheduler = scheduler_class(**scheduler_config)
-
-            scheduler_config["tensor_format"] = "pt"
-            scheduler_pt = scheduler_class(**scheduler_config)
-
-            scheduler.set_timesteps(self.num_inference_steps)
-            scheduler_pt.set_timesteps(self.num_inference_steps)
-
-            output = scheduler.step(residual, 1, sample).prev_sample
-            output_pt = scheduler_pt.step(residual_pt, 1, sample_pt).prev_sample
-
-            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
-
     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
@@ -1001,5 +876,5 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
         result_sum = torch.sum(torch.abs(sample))
         result_mean = torch.mean(torch.abs(sample))

-        assert abs(result_sum.item() - 1006.388) < 1e-2
+        assert abs(result_sum.item() - 1006.370) < 1e-2
         assert abs(result_mean.item() - 1.31) < 1e-3
@@ -41,7 +41,6 @@ class TrainingTests(unittest.TestCase):
             beta_end=0.02,
             beta_schedule="linear",
             clip_sample=True,
-            tensor_format="pt",
         )
         ddim_scheduler = DDIMScheduler(
             num_train_timesteps=1000,
@@ -49,7 +48,6 @@ class TrainingTests(unittest.TestCase):
             beta_end=0.02,
             beta_schedule="linear",
             clip_sample=True,
-            tensor_format="pt",
         )

         assert ddpm_scheduler.config.num_train_timesteps == ddim_scheduler.config.num_train_timesteps
...