Unverified commit d7dcba4a, authored by Jonatan Kłosko and committed by GitHub

Unify offset configuration in DDIM and PNDM schedulers (#479)



* Unify offset configuration in DDIM and PNDM schedulers

* Format

Add missing variables

* Fix pipeline test

* Update src/diffusers/schedulers/scheduling_ddim.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Default set_alpha_to_one to false

* Format

* Add tests

* Format

* add deprecation warning
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 9e439d8c
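
The change in a nutshell: instead of the pipelines passing `offset=1` to `scheduler.set_timesteps()` at call time, the offset now lives in the scheduler configuration as `steps_offset`. A minimal before/after sketch (illustrative only; it builds a PNDM scheduler by hand rather than loading one from a checkpoint):

```python
from diffusers import PNDMScheduler

# Before this commit: the pipeline shifted timesteps at call time.
old_style = PNDMScheduler(skip_prk_steps=True)
old_style.set_timesteps(50, offset=1)  # `offset` is now deprecated and emits a warning

# After this commit: the offset is part of the scheduler config.
new_style = PNDMScheduler(skip_prk_steps=True, steps_offset=1)
new_style.set_timesteps(50)  # picks up steps_offset from the config
```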
@@ -6,6 +6,7 @@ import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -53,6 +54,21 @@ class StableDiffusionPipeline(DiffusionPipeline):
):
super().__init__()
scheduler = scheduler.set_format("pt")
+if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+warnings.warn(
+f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+"to update the config accordingly, as leaving `steps_offset` unchanged might lead to incorrect results"
+" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+" it would be very nice if you could open a pull request for the `scheduler/scheduler_config.json`"
+" file",
+DeprecationWarning,
+)
+new_config = dict(scheduler.config)
+new_config["steps_offset"] = 1
+scheduler._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -217,12 +233,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
latents = latents.to(self.device)
# set timesteps
-accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
-extra_set_kwargs = {}
-if accepts_offset:
-extra_set_kwargs["offset"] = 1
-self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+self.scheduler.set_timesteps(num_inference_steps)
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
if isinstance(self.scheduler, LMSDiscreteScheduler):
......
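
The deprecation branch above only patches the in-memory config. The warning asks checkpoint maintainers to persist the fix in `scheduler/scheduler_config.json`; a hedged sketch of doing that programmatically (it assumes a hypothetical local folder `./my-checkpoint`, and that your diffusers version lets `from_config` override config values via keyword arguments):

```python
from diffusers import PNDMScheduler

# Load the existing scheduler config and override the outdated field.
scheduler = PNDMScheduler.from_config("./my-checkpoint/scheduler", steps_offset=1)

# Write the corrected config back so the warning no longer triggers.
scheduler.save_config("./my-checkpoint/scheduler")
```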
import inspect
+import warnings
from typing import List, Optional, Union
import numpy as np
@@ -7,6 +8,7 @@ import torch
import PIL
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -64,6 +66,21 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
):
super().__init__()
scheduler = scheduler.set_format("pt")
+if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+warnings.warn(
+f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+"to update the config accordingly, as leaving `steps_offset` unchanged might lead to incorrect results"
+" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+" it would be very nice if you could open a pull request for the `scheduler/scheduler_config.json`"
+" file",
+DeprecationWarning,
+)
+new_config = dict(scheduler.config)
+new_config["steps_offset"] = 1
+scheduler._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -169,14 +186,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
# set timesteps
-accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
-extra_set_kwargs = {}
-offset = 0
-if accepts_offset:
-offset = 1
-extra_set_kwargs["offset"] = 1
-self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+self.scheduler.set_timesteps(num_inference_steps)
if isinstance(init_image, PIL.Image.Image):
init_image = preprocess(init_image)
@@ -190,6 +200,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
init_latents = torch.cat([init_latents] * batch_size)
# get the original timestep using init_timestep
+offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
if isinstance(self.scheduler, LMSDiscreteScheduler):
......
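
For img2img (and inpainting below), `steps_offset` now flows into the strength-based truncation via `offset = self.scheduler.config.get("steps_offset", 0)`. A small worked sketch of that arithmetic, with values chosen purely for illustration:

```python
num_inference_steps = 50
strength = 0.75
steps_offset = 1  # the value recommended for stable diffusion checkpoints

# Mirrors the pipeline logic shown above.
init_timestep = int(num_inference_steps * strength) + steps_offset  # 37 + 1 = 38
init_timestep = min(init_timestep, num_inference_steps)             # still 38
print(init_timestep)
```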
import inspect
+import warnings
from typing import List, Optional, Union
import numpy as np
@@ -8,6 +9,7 @@ import PIL
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, PNDMScheduler
@@ -83,6 +85,21 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
super().__init__()
scheduler = scheduler.set_format("pt")
logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
+if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+warnings.warn(
+f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+"to update the config accordingly, as leaving `steps_offset` unchanged might lead to incorrect results"
+" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+" it would be very nice if you could open a pull request for the `scheduler/scheduler_config.json`"
+" file",
+DeprecationWarning,
+)
+new_config = dict(scheduler.config)
+new_config["steps_offset"] = 1
+scheduler._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -193,19 +210,12 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
# set timesteps
-accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
-extra_set_kwargs = {}
-offset = 0
-if accepts_offset:
-offset = 1
-extra_set_kwargs["offset"] = 1
-self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+self.scheduler.set_timesteps(num_inference_steps)
# preprocess image
if not isinstance(init_image, torch.FloatTensor):
init_image = preprocess_image(init_image)
-init_image.to(self.device)
+init_image = init_image.to(self.device)
# encode the init image into latents and scale the latents
init_latent_dist = self.vae.encode(init_image).latent_dist
@@ -220,7 +230,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# preprocess mask
if not isinstance(mask_image, torch.FloatTensor):
mask_image = preprocess_mask(mask_image)
-mask_image.to(self.device)
+mask_image = mask_image.to(self.device)
mask = torch.cat([mask_image] * batch_size)
# check sizes
@@ -228,6 +238,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
raise ValueError("The mask and init_image should be the same size!")
# get the original timestep using init_timestep
+offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
timesteps = self.scheduler.timesteps[-init_timestep]
......
@@ -100,12 +100,7 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
# set timesteps
-accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
-extra_set_kwargs = {}
-if accepts_offset:
-extra_set_kwargs["offset"] = 1
-self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+self.scheduler.set_timesteps(num_inference_steps)
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
if isinstance(self.scheduler, LMSDiscreteScheduler):
......
@@ -16,6 +16,7 @@
# and https://github.com/hojonathanho/diffusion
import math
+import warnings
from typing import Optional, Tuple, Union
import numpy as np
@@ -78,7 +79,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
set_alpha_to_one (`bool`, default `True`):
-if alpha for final step is 1 or the final alpha of the "non-previous" one.
+each diffusion step uses the value of the alphas product at that step and at the previous one. For the final
+step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+otherwise it uses the value of alpha at step 0.
+steps_offset (`int`, default `0`):
+an offset added to the inference steps. You can use a combination of `steps_offset=1` and
+`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product, as done in
+stable diffusion.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
@@ -93,6 +100,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[np.ndarray] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
+steps_offset: int = 0,
tensor_format: str = "pt",
):
if trained_betas is not None:
@@ -134,16 +142,26 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return variance
-def set_timesteps(self, num_inference_steps: int, offset: int = 0):
+def set_timesteps(self, num_inference_steps: int, **kwargs):
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
-offset (`int`):
-optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
"""
+offset = self.config.steps_offset
+if "offset" in kwargs:
+warnings.warn(
+"`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
+" Please pass `steps_offset` to `__init__` instead.",
+DeprecationWarning,
+)
+offset = kwargs["offset"]
self.num_inference_steps = num_inference_steps
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
......
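
To make the new DDIM docstring concrete: `steps_offset` simply shifts every inference timestep, and `set_alpha_to_one` decides which alpha product the final step falls back to once the previous timestep drops below zero. A hedged numpy sketch of that arithmetic (not the scheduler code itself; the linear beta schedule and all values are illustrative defaults):

```python
import numpy as np

num_train_timesteps = 1000
num_inference_steps = 5
steps_offset = 1
set_alpha_to_one = False

# Timestep construction, mirroring what set_timesteps does with the configured offset.
step_ratio = num_train_timesteps // num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1] + steps_offset
print(timesteps)  # [801. 601. 401. 201. 1.] -- matches the new test_steps_offset expectation

# Previous-alpha lookup for the very last step, where prev_timestep goes negative.
betas = np.linspace(0.0001, 0.02, num_train_timesteps)  # assumed linear schedule
alphas_cumprod = np.cumprod(1.0 - betas)
final_alpha_cumprod = 1.0 if set_alpha_to_one else alphas_cumprod[0]
prev_timestep = int(timesteps[-1]) - step_ratio  # 1 - 200 = -199
alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alpha_cumprod
```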
@@ -15,6 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
+import warnings
from typing import Optional, Tuple, Union
import numpy as np
@@ -74,10 +75,18 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
-tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
skip_prk_steps (`bool`):
allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
before plms steps; defaults to `False`.
+set_alpha_to_one (`bool`, default `False`):
+each diffusion step uses the value of the alphas product at that step and at the previous one. For the final
+step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+otherwise it uses the value of alpha at step 0.
+steps_offset (`int`, default `0`):
+an offset added to the inference steps. You can use a combination of `steps_offset=1` and
+`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product, as done in
+stable diffusion.
+tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
"""
@@ -89,8 +98,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
-tensor_format: str = "pt",
skip_prk_steps: bool = False,
+set_alpha_to_one: bool = False,
+steps_offset: int = 0,
+tensor_format: str = "pt",
):
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
@@ -108,6 +119,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
# For now we only support F-PNDM, i.e. the runge-kutta method
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
# mainly at formula (9), (12), (13) and the Algorithm 2.
@@ -122,7 +135,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# setable values
self.num_inference_steps = None
self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
-self._offset = 0
self.prk_timesteps = None
self.plms_timesteps = None
self.timesteps = None
@@ -130,23 +142,31 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
-def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor:
+def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
-offset (`int`):
-optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
"""
+offset = self.config.steps_offset
+if "offset" in kwargs:
+warnings.warn(
+"`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
+" Please pass `steps_offset` to `__init__` instead."
+)
+offset = kwargs["offset"]
self.num_inference_steps = num_inference_steps
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
-self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().tolist()
-self._offset = offset
-self._timesteps = np.array([t + self._offset for t in self._timesteps])
+self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()
+self._timesteps += offset
if self.config.skip_prk_steps:
# for some models like stable diffusion the prk steps can/should be skipped to
@@ -231,7 +251,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
)
diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
-prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
+prev_timestep = timestep - diff_to_prev
timestep = self.prk_timesteps[self.counter // 4 * 4]
if self.counter % 4 == 0:
@@ -293,7 +313,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"for more information."
)
-prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0)
+prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
if self.counter != 1:
self.ets.append(model_output)
@@ -323,7 +343,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample)
-def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
+def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
# this function computes x_(t−δ) using the formula of (9)
# Note that x_t needs to be added to both sides of the equation
@@ -336,8 +356,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# sample -> x_t
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(t−δ)
-alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset]
-alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset]
+alpha_prod_t = self.alphas_cumprod[timestep]
+alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
......
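
The PNDM change mirrors the DDIM one, but it also retires the old `self._offset` bookkeeping: `_get_prev_sample` now indexes `alphas_cumprod` directly with the (possibly negative) `prev_timestep` and falls back to the new `final_alpha_cumprod`. A hedged sketch of just that lookup, with illustrative values standing in for the scheduler's precomputed arrays:

```python
import numpy as np

alphas_cumprod = np.cumprod(1.0 - np.linspace(0.0001, 0.02, 1000))  # assumed linear schedule
final_alpha_cumprod = np.array(1.0)  # set_alpha_to_one=True; use alphas_cumprod[0] otherwise

def prev_alpha(prev_timestep: int) -> np.ndarray:
    # New behaviour: a negative prev_timestep means "before step 0", so use the
    # configured final alpha instead of wrapping around the array.
    return alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alpha_cumprod

print(prev_alpha(801))   # a regular step
print(prev_alpha(-199))  # e.g. the last step with steps_offset=1 and 5 inference steps
```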
@@ -357,10 +357,38 @@ class DDIMSchedulerTest(SchedulerCommonTest):
config.update(**kwargs)
return config
+def full_loop(self, **config):
+scheduler_class = self.scheduler_classes[0]
+scheduler_config = self.get_scheduler_config(**config)
+scheduler = scheduler_class(**scheduler_config)
+num_inference_steps, eta = 10, 0.0
+model = self.dummy_model()
+sample = self.dummy_sample_deter
+scheduler.set_timesteps(num_inference_steps)
+for t in scheduler.timesteps:
+residual = model(sample, t)
+sample = scheduler.step(residual, t, sample, eta).prev_sample
+return sample
def test_timesteps(self):
for timesteps in [100, 500, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
+def test_steps_offset(self):
+for steps_offset in [0, 1]:
+self.check_over_configs(steps_offset=steps_offset)
+scheduler_class = self.scheduler_classes[0]
+scheduler_config = self.get_scheduler_config(steps_offset=1)
+scheduler = scheduler_class(**scheduler_config)
+scheduler.set_timesteps(5)
+assert torch.equal(scheduler.timesteps, torch.tensor([801, 601, 401, 201, 1]))
def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
@@ -398,26 +426,31 @@ class DDIMSchedulerTest(SchedulerCommonTest):
assert torch.sum(torch.abs(scheduler._get_variance(999, 998) - 0.02)) < 1e-5
def test_full_loop_no_noise(self):
-scheduler_class = self.scheduler_classes[0]
-scheduler_config = self.get_scheduler_config()
-scheduler = scheduler_class(**scheduler_config)
-num_inference_steps, eta = 10, 0.0
-model = self.dummy_model()
-sample = self.dummy_sample_deter
-scheduler.set_timesteps(num_inference_steps)
-for t in scheduler.timesteps:
-residual = model(sample, t)
-sample = scheduler.step(residual, t, sample, eta).prev_sample
+sample = self.full_loop()
+result_sum = torch.sum(torch.abs(sample))
+result_mean = torch.mean(torch.abs(sample))
+assert abs(result_sum.item() - 172.0067) < 1e-2
+assert abs(result_mean.item() - 0.223967) < 1e-3
+def test_full_loop_with_set_alpha_to_one(self):
+# We specify different beta, so that the first alpha is 0.99
+sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
+result_sum = torch.sum(torch.abs(sample))
+result_mean = torch.mean(torch.abs(sample))
+assert abs(result_sum.item() - 149.8295) < 1e-2
+assert abs(result_mean.item() - 0.1951) < 1e-3
+def test_full_loop_with_no_set_alpha_to_one(self):
+# We specify different beta, so that the first alpha is 0.99
+sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
-assert abs(result_sum.item() - 172.0067) < 1e-2
-assert abs(result_mean.item() - 0.223967) < 1e-3
+assert abs(result_sum.item() - 149.0784) < 1e-2
+assert abs(result_mean.item() - 0.1941) < 1e-3
class PNDMSchedulerTest(SchedulerCommonTest):
@@ -503,6 +536,26 @@ class PNDMSchedulerTest(SchedulerCommonTest):
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+def full_loop(self, **config):
+scheduler_class = self.scheduler_classes[0]
+scheduler_config = self.get_scheduler_config(**config)
+scheduler = scheduler_class(**scheduler_config)
+num_inference_steps = 10
+model = self.dummy_model()
+sample = self.dummy_sample_deter
+scheduler.set_timesteps(num_inference_steps)
+for i, t in enumerate(scheduler.prk_timesteps):
+residual = model(sample, t)
+sample = scheduler.step_prk(residual, t, sample).prev_sample
+for i, t in enumerate(scheduler.plms_timesteps):
+residual = model(sample, t)
+sample = scheduler.step_plms(residual, t, sample).prev_sample
+return sample
def test_pytorch_equal_numpy(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
@@ -606,8 +659,23 @@ class PNDMSchedulerTest(SchedulerCommonTest):
for timesteps in [100, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
+def test_steps_offset(self):
+for steps_offset in [0, 1]:
+self.check_over_configs(steps_offset=steps_offset)
+scheduler_class = self.scheduler_classes[0]
+scheduler_config = self.get_scheduler_config(steps_offset=1)
+scheduler = scheduler_class(**scheduler_config)
+scheduler.set_timesteps(10)
+assert torch.equal(
+scheduler.timesteps,
+torch.tensor(
+[901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]
+),
+)
def test_betas(self):
-for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
+for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
def test_schedules(self):
@@ -620,7 +688,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
def test_inference_steps(self):
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
-self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
+self.check_over_forward(num_inference_steps=num_inference_steps)
def test_pow_of_3_inference_steps(self):
# earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3
@@ -648,28 +716,30 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample).prev_sample
def test_full_loop_no_noise(self):
-scheduler_class = self.scheduler_classes[0]
-scheduler_config = self.get_scheduler_config()
-scheduler = scheduler_class(**scheduler_config)
-num_inference_steps = 10
-model = self.dummy_model()
-sample = self.dummy_sample_deter
-scheduler.set_timesteps(num_inference_steps)
-for i, t in enumerate(scheduler.prk_timesteps):
-residual = model(sample, t)
-sample = scheduler.step_prk(residual, i, sample).prev_sample
-for i, t in enumerate(scheduler.plms_timesteps):
-residual = model(sample, t)
-sample = scheduler.step_plms(residual, i, sample).prev_sample
+sample = self.full_loop()
+result_sum = torch.sum(torch.abs(sample))
+result_mean = torch.mean(torch.abs(sample))
+assert abs(result_sum.item() - 198.1318) < 1e-2
+assert abs(result_mean.item() - 0.2580) < 1e-3
+def test_full_loop_with_set_alpha_to_one(self):
+# We specify different beta, so that the first alpha is 0.99
+sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
+result_sum = torch.sum(torch.abs(sample))
+result_mean = torch.mean(torch.abs(sample))
+assert abs(result_sum.item() - 230.0399) < 1e-2
+assert abs(result_mean.item() - 0.2995) < 1e-3
+def test_full_loop_with_no_set_alpha_to_one(self):
+# We specify different beta, so that the first alpha is 0.99
+sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
-assert abs(result_sum.item() - 428.8788) < 1e-2
-assert abs(result_mean.item() - 0.5584) < 1e-3
+assert abs(result_sum.item() - 186.9482) < 1e-2
+assert abs(result_mean.item() - 0.2434) < 1e-3
class ScoreSdeVeSchedulerTest(unittest.TestCase):
......