Unverified Commit 07c9a08e authored by Pedro Cuenca, committed by GitHub

Add `timestep_spacing` and `steps_offset` to schedulers (#3947)



* Add timestep_spacing to DDPM, LMSDiscrete, PNDM.

* Remove spurious line.

* More easy schedulers.

* Add `linspace` to DDIM

* Noise sigma for `trailing`.

* Add timestep_spacing to DEISMultistepScheduler.

Not sure the range is the way it was intended.

* Fix: remove line used to debug.

* Support timestep_spacing in DPMSolverMultistep, DPMSolverSDE, UniPC

* Fix: convert to numpy.

* Use sched. defaults when instantiating from_config

For params not present in the original configuration.

This makes it possible to switch pipeline schedulers even if they use
different timestep_spacing (or any other param).
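  A minimal sketch of the intended behaviour (illustrative; the printed values assume the scheduler defaults introduced in this PR):

  from diffusers import DPMSolverMultistepScheduler, PNDMScheduler

  # Instantiate from a config that predates `timestep_spacing`: the param falls back
  # to PNDM's own default and is recorded in the config as `_use_default_values`.
  pndm = PNDMScheduler.from_config({"num_train_timesteps": 1000})
  print(pndm.config.timestep_spacing)  # "leading" (PNDM default)

  # Switching schedulers no longer carries that defaulted value over;
  # DPMSolverMultistep falls back to its own default instead.
  dpm = DPMSolverMultistepScheduler.from_config(pndm.config)
  print(dpm.config.timestep_spacing)   # "linspace" (DPMSolverMultistep default)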

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Missing args in DPMSolverMultistep

* Test: default args not in config

* Style

* Fix scheduler name in test

* Remove duplicated entries

* Add test for solver_type

This test currently fails in main. When switching from DEIS to UniPC,
solver_type is "logrho" (the default value from DEIS), which gets
translated to "bh1" by UniPC. This is different to the default value for
UniPC: "bh2". This is where the translation happens: https://github.com/huggingface/diffusers/blob/36d22d0709dc19776e3016fb3392d0f5578b0ab2/src/diffusers/schedulers/scheduling_unipc_multistep.py#L171



* UniPC: use same default for solver_type

Fixes a bug when switching to UniPC from another scheduler (e.g.,
DEIS) that uses a different solver type. The solver is now the same as
if we had instantiated the scheduler directly (see the sketch below).
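A short sketch of the scenario this addresses (illustrative; the expected value follows the defaults described above):

  from diffusers import DEISMultistepScheduler, UniPCMultistepScheduler

  deis = DEISMultistepScheduler()  # solver_type defaults to "logrho"
  # Before the fix, the defaulted "logrho" leaked into UniPC and was silently
  # mapped to "bh1"; now the defaulted param is dropped and UniPC uses its own default.
  unipc = UniPCMultistepScheduler.from_config(deis.config)
  print(unipc.config.solver_type)  # expected: "bh2", same as UniPCMultistepScheduler()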

* Do not save default values

* fix more

* fix all

* fix schedulers

* fix more

* finish for real

* finish for real

* flaky tests

* Update tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py

* Default steps_offset to 0.

* Add missing docstrings

* Apply suggestions from code review

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 2837d490
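For context, the two new arguments are regular config parameters; a hedged usage sketch (printed values are illustrative):

from diffusers import DDIMScheduler

# "trailing" ends the schedule at the last training timestep (999), as recommended in
# https://arxiv.org/abs/2305.08891, whereas the "leading" default starts at 0 and never reaches it.
scheduler = DDIMScheduler(num_train_timesteps=1000, timestep_spacing="trailing", steps_offset=0)
scheduler.set_timesteps(10)
print(scheduler.timesteps)  # tensor([999, 899, 799, 699, 599, 499, 399, 299, 199, 99])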
@@ -423,6 +423,10 @@ class ConfigMixin:
@classmethod
def extract_init_dict(cls, config_dict, **kwargs):
# Skip keys that were not present in the original config, so default __init__ values were used
used_defaults = config_dict.get("_use_default_values", [])
config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}
# 0. Copy origin config dict
original_dict = dict(config_dict.items())
@@ -544,8 +548,9 @@ class ConfigMixin:
return value
config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
# Don't save "_ignore_files" or "_use_default_values"
config_dict.pop("_ignore_files", None)
config_dict.pop("_use_default_values", None)
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
@@ -599,6 +604,11 @@ def register_to_config(init):
if k not in ignore and k not in new_kwargs
}
)
# Take note of the parameters that were not present in the loaded config
if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
new_kwargs["_use_default_values"] = set(new_kwargs.keys()) - set(init_kwargs)
new_kwargs = {**config_init_kwargs, **new_kwargs}
getattr(self, "register_to_config")(**new_kwargs)
init(self, *args, **init_kwargs)
@@ -643,6 +653,10 @@ def flax_register_to_config(cls):
name = fields[i].name
new_kwargs[name] = arg
# Take note of the parameters that were not present in the loaded config
if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
new_kwargs["_use_default_values"] = set(new_kwargs.keys()) - set(init_kwargs)
getattr(self, "register_to_config")(**new_kwargs)
original_init(self, *args, **kwargs)
......
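The net effect of the ConfigMixin / register_to_config hunks above, sketched (illustrative; names as in the diff):

from diffusers import DDPMScheduler

sched = DDPMScheduler.from_config({"num_train_timesteps": 1000})
# Params that fell back to their defaults are tracked in the in-memory config...
print("_use_default_values" in sched.config)             # True
# ...but stripped again when the config is serialized.
print("_use_default_values" in sched.to_json_string())   # False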
@@ -302,8 +302,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
self.num_inference_steps = num_inference_steps
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
.round()[::-1]
.copy()
.astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
......
@@ -321,8 +321,15 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
self.num_inference_steps = num_inference_steps
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
.round()[::-1]
.copy()
.astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
......
@@ -114,6 +114,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
(https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
timestep_spacing (`str`, default `"leading"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
...@@ -134,6 +141,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -134,6 +141,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
dynamic_thresholding_ratio: float = 0.995, dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0, clip_sample_range: float = 1.0,
sample_max_value: float = 1.0, sample_max_value: float = 1.0,
timestep_spacing: str = "leading",
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -228,11 +237,33 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
)
self.num_inference_steps = num_inference_steps
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
self.custom_timesteps = False
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
.round()[::-1]
.copy()
.astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
self.timesteps = torch.from_numpy(timesteps).to(device)
def _get_variance(self, t, predicted_variance=None, variance_type=None):
......
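For reference, a standalone numpy sketch of the three spacing strategies implemented in the hunk above (1000 training timesteps, 10 inference steps, steps_offset left at 0):

import numpy as np

num_train_timesteps, num_inference_steps = 1000, 10

# "linspace": evenly spaced over the full range, inclusive of both endpoints.
linspace = np.linspace(0, num_train_timesteps - 1, num_inference_steps).round()[::-1].astype(np.int64)

# "leading": integer multiples of the ratio, starting at 0 (plus steps_offset).
step_ratio = num_train_timesteps // num_inference_steps
leading = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64)

# "trailing": counts down from the last training step, so 999 is always included.
step_ratio = num_train_timesteps / num_inference_steps
trailing = np.round(np.arange(num_train_timesteps, 0, -step_ratio)).astype(np.int64) - 1

print(linspace)  # [999 888 777 666 555 444 333 222 111   0]
print(leading)   # [900 800 700 600 500 400 300 200 100   0]
print(trailing)  # [999 899 799 699 599 499 399 299 199  99]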
@@ -116,6 +116,13 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
(https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
timestep_spacing (`str`, default `"leading"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
...@@ -138,6 +145,8 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): ...@@ -138,6 +145,8 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
dynamic_thresholding_ratio: float = 0.995, dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0, clip_sample_range: float = 1.0,
sample_max_value: float = 1.0, sample_max_value: float = 1.0,
timestep_spacing: str = "leading",
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -234,11 +243,33 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
)
self.num_inference_steps = num_inference_steps
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
self.custom_timesteps = False
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
.round()[::-1]
.copy()
.astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
self.timesteps = torch.from_numpy(timesteps).to(device)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._get_variance
......
@@ -107,6 +107,13 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
timestep_spacing (`str`, default `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
...@@ -129,6 +136,8 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -129,6 +136,8 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
solver_type: str = "logrho", solver_type: str = "logrho",
lower_order_final: bool = True, lower_order_final: bool = True,
use_karras_sigmas: Optional[bool] = False, use_karras_sigmas: Optional[bool] = False,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -185,12 +194,30 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
.round()[::-1][:-1]
.copy()
.astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1)
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.use_karras_sigmas:
......
@@ -134,6 +134,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the
Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on
diffusion ODEs.
timestep_spacing (`str`, default `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
...@@ -158,6 +165,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -158,6 +165,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
use_karras_sigmas: Optional[bool] = False, use_karras_sigmas: Optional[bool] = False,
lambda_min_clipped: float = -float("inf"), lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None, variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -217,12 +226,29 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# Clipping the minimum of all lambda(t) for numerical stability.
# This is critical for cosine (squaredcos_cap_v2) noise schedule.
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()

# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = last_timestep // (num_inference_steps + 1)
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.use_karras_sigmas:
......
@@ -133,6 +133,13 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
noise_sampler_seed (`int`, *optional*, defaults to `None`):
The random seed to use for the noise sampler. If `None`, a random seed will be generated.
timestep_spacing (`str`, default `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
...@@ -149,6 +156,8 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): ...@@ -149,6 +156,8 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
prediction_type: str = "epsilon", prediction_type: str = "epsilon",
use_karras_sigmas: Optional[bool] = False, use_karras_sigmas: Optional[bool] = False,
noise_sampler_seed: Optional[int] = None, noise_sampler_seed: Optional[int] = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -187,6 +196,14 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
pos = 0
return indices[pos].item()
@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
if self.config.timestep_spacing in ["linspace", "trailing"]:
return self.sigmas.max()
return (self.sigmas.max() ** 2 + 1) ** 0.5
def scale_model_input(
self,
sample: torch.FloatTensor,
@@ -226,7 +243,25 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
@@ -242,9 +277,6 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
sigmas = torch.from_numpy(sigmas).to(device=device)
self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])
# standard deviation of the initial noise distribution
self.init_noise_sigma = self.sigmas.max()
timesteps = torch.from_numpy(timesteps)
second_order_timesteps = torch.from_numpy(second_order_timesteps)
timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
......
@@ -99,7 +99,13 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
timestep_spacing (`str`, default `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
...@@ -114,6 +120,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -114,6 +120,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_schedule: str = "linear", beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None, trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon", prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -137,15 +145,20 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)
# standard deviation of the initial noise distribution
self.init_noise_sigma = self.sigmas.max()
# setable values
self.num_inference_steps = None
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
self.timesteps = torch.from_numpy(timesteps)
self.is_scale_input_called = False
@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
if self.config.timestep_spacing in ["linspace", "trailing"]:
return self.sigmas.max()
return (self.sigmas.max() ** 2 + 1) ** 0.5
def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
@@ -179,7 +192,28 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
self.num_inference_steps = num_inference_steps
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[
::-1
].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
......
@@ -107,6 +107,13 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
timestep_spacing (`str`, default `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
...@@ -123,6 +130,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -123,6 +130,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
prediction_type: str = "epsilon", prediction_type: str = "epsilon",
interpolation_type: str = "linear", interpolation_type: str = "linear",
use_karras_sigmas: Optional[bool] = False, use_karras_sigmas: Optional[bool] = False,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -146,9 +155,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)
# standard deviation of the initial noise distribution
self.init_noise_sigma = self.sigmas.max()
# setable values
self.num_inference_steps = None
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
@@ -156,6 +162,14 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.is_scale_input_called = False
self.use_karras_sigmas = use_karras_sigmas
@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
if self.config.timestep_spacing in ["linspace", "trailing"]:
return self.sigmas.max()
return (self.sigmas.max() ** 2 + 1) ** 0.5
def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
@@ -191,7 +205,28 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
self.num_inference_steps = num_inference_steps
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[
::-1
].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
......
@@ -78,6 +78,13 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
timestep_spacing (`str`, default `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
...@@ -93,6 +100,8 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -93,6 +100,8 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[Union[np.ndarray, List[float]]] = None, trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon", prediction_type: str = "epsilon",
use_karras_sigmas: Optional[bool] = False, use_karras_sigmas: Optional[bool] = False,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -128,6 +137,14 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
pos = 0
return indices[pos].item()
@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
if self.config.timestep_spacing in ["linspace", "trailing"]:
return self.sigmas.max()
return (self.sigmas.max() ** 2 + 1) ** 0.5
def scale_model_input(
self,
sample: torch.FloatTensor,
@@ -166,7 +183,25 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
@@ -180,9 +215,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = torch.from_numpy(sigmas).to(device=device)
self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])
# standard deviation of the initial noise distribution
self.init_noise_sigma = self.sigmas.max()
timesteps = torch.from_numpy(timesteps)
timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
......
@@ -78,6 +78,13 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
timestep_spacing (`str`, default `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
...@@ -92,6 +99,8 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -92,6 +99,8 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_schedule: str = "linear", beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None, trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon", prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -127,6 +136,14 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
pos = 0
return indices[pos].item()
@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
if self.config.timestep_spacing in ["linspace", "trailing"]:
return self.sigmas.max()
return (self.sigmas.max() ** 2 + 1) ** 0.5
def scale_model_input(
self,
sample: torch.FloatTensor,
@@ -169,7 +186,25 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
@@ -197,9 +232,6 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]])
self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])
# standard deviation of the initial noise distribution
self.init_noise_sigma = self.sigmas.max()
if str(device).startswith("mps"):
# mps does not support float64
timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
......
@@ -77,6 +77,13 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
timestep_spacing (`str`, default `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
...@@ -91,6 +98,8 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -91,6 +98,8 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_schedule: str = "linear", beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None, trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon", prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -126,6 +135,14 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
pos = 0
return indices[pos].item()
@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
if self.config.timestep_spacing in ["linspace", "trailing"]:
return self.sigmas.max()
return (self.sigmas.max() ** 2 + 1) ** 0.5
def scale_model_input(
self,
sample: torch.FloatTensor,
@@ -168,7 +185,25 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
@@ -185,9 +220,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
[sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]]
)
# standard deviation of the initial noise distribution
self.init_noise_sigma = self.sigmas.max()
if str(device).startswith("mps"):
# mps does not support float64
timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
......
@@ -102,6 +102,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
timestep_spacing (`str`, default `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
...@@ -117,6 +124,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -117,6 +124,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[Union[np.ndarray, List[float]]] = None, trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
use_karras_sigmas: Optional[bool] = False, use_karras_sigmas: Optional[bool] = False,
prediction_type: str = "epsilon", prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -140,9 +149,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)
# standard deviation of the initial noise distribution
self.init_noise_sigma = self.sigmas.max()
# setable values
self.num_inference_steps = None
self.use_karras_sigmas = use_karras_sigmas
@@ -150,6 +156,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.derivatives = []
self.is_scale_input_called = False
@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
if self.config.timestep_spacing in ["linspace", "trailing"]:
return self.sigmas.max()
return (self.sigmas.max() ** 2 + 1) ** 0.5
def scale_model_input(
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
) -> torch.FloatTensor:
@@ -205,7 +219,27 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
self.num_inference_steps = num_inference_steps
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[
::-1
].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas) log_sigmas = np.log(sigmas)
......
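For intuition, the three branches above produce different timestep grids. A small NumPy sketch with illustrative values (`num_train_timesteps=1000`, `num_inference_steps=10`, `steps_offset=0`), mirroring the logic in `set_timesteps`:

```python
import numpy as np

num_train_timesteps, num_inference_steps, steps_offset = 1000, 10, 0

# "linspace": evenly spaced over [0, num_train_timesteps - 1]
linspace = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
# -> [999. 888. 777. 666. 555. 444. 333. 222. 111.   0.]

# "leading": integer multiples of the floor ratio, optionally shifted by steps_offset
step_ratio = num_train_timesteps // num_inference_steps
leading = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
leading += steps_offset
# -> [900. 800. 700. 600. 500. 400. 300. 200. 100.   0.]

# "trailing": counts down from the end of training so the final training step is always included
step_ratio = num_train_timesteps / num_inference_steps
trailing = np.arange(num_train_timesteps, 0, -step_ratio).round().copy().astype(float) - 1
# -> [999. 899. 799. 699. 599. 499. 399. 299. 199.  99.]
```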
...@@ -85,11 +85,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -85,11 +85,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
prediction_type (`str`, default `epsilon`, optional): prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process) prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process)
or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf)
timestep_spacing (`str`, default `"leading"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
steps_offset (`int`, default `0`): steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion. stable diffusion.
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
...@@ -106,6 +108,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -106,6 +108,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
skip_prk_steps: bool = False, skip_prk_steps: bool = False,
set_alpha_to_one: bool = False, set_alpha_to_one: bool = False,
prediction_type: str = "epsilon", prediction_type: str = "epsilon",
timestep_spacing: str = "leading",
steps_offset: int = 0, steps_offset: int = 0,
): ):
if trained_betas is not None: if trained_betas is not None:
...@@ -159,11 +162,29 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -159,11 +162,29 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
""" """
self.num_inference_steps = num_inference_steps self.num_inference_steps = num_inference_steps
step_ratio = self.config.num_train_timesteps // self.num_inference_steps # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
# creates integer timesteps by multiplying by ratio if self.config.timestep_spacing == "linspace":
# casting to int to avoid issues when num_inference_step is power of 3 self._timesteps = (
self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round() np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round().astype(np.int64)
self._timesteps += self.config.steps_offset )
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()
self._timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
self._timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio))[::-1].astype(
np.int64
)
self._timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
if self.config.skip_prk_steps: if self.config.skip_prk_steps:
# for some models like stable diffusion the prk steps can/should be skipped to # for some models like stable diffusion the prk steps can/should be skipped to
......
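PNDM keeps the same three options but builds its `_timesteps` in ascending order, and the `"leading"` branch still honours `steps_offset`. A short sketch of the offset's effect with illustrative values (`steps_offset=1`, the setting the docstring mentions for Stable Diffusion):

```python
import numpy as np

num_train_timesteps, num_inference_steps, steps_offset = 1000, 10, 1

step_ratio = num_train_timesteps // num_inference_steps
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()
timesteps += steps_offset
# -> [  1 101 201 301 401 501 601 701 801 901]
```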
...@@ -121,6 +121,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -121,6 +121,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
timestep_spacing (`str`, default `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
...@@ -145,6 +152,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -145,6 +152,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
disable_corrector: List[int] = [], disable_corrector: List[int] = [],
solver_p: SchedulerMixin = None, solver_p: SchedulerMixin = None,
use_karras_sigmas: Optional[bool] = False, use_karras_sigmas: Optional[bool] = False,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
): ):
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32) self.betas = torch.tensor(trained_betas, dtype=torch.float32)
...@@ -173,7 +182,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -173,7 +182,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
if solver_type not in ["bh1", "bh2"]: if solver_type not in ["bh1", "bh2"]:
if solver_type in ["midpoint", "heun", "logrho"]: if solver_type in ["midpoint", "heun", "logrho"]:
self.register_to_config(solver_type="bh1") self.register_to_config(solver_type="bh2")
else: else:
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
...@@ -199,12 +208,30 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -199,12 +208,30 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
device (`str` or `torch.device`, optional): device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
""" """
timesteps = ( # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) if self.config.timestep_spacing == "linspace":
.round()[::-1][:-1] timesteps = (
.copy() np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
.astype(np.int64) .round()[::-1][:-1]
) .copy()
.astype(np.int64)
)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1)
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.use_karras_sigmas: if self.config.use_karras_sigmas:
......
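Note that UniPC's `"leading"` branch spaces `num_inference_steps + 1` points and drops the final one, so its grid differs slightly from the LMS/PNDM variants above. For illustration (`num_train_timesteps=1000`, `num_inference_steps=10`):

```python
import numpy as np

num_train_timesteps, num_inference_steps = 1000, 10

step_ratio = num_train_timesteps // (num_inference_steps + 1)  # 90
timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
# -> [900 810 720 630 540 450 360 270 180  90]
```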
...@@ -75,6 +75,22 @@ class SampleObject3(ConfigMixin): ...@@ -75,6 +75,22 @@ class SampleObject3(ConfigMixin):
pass pass
class SampleObject4(ConfigMixin):
config_name = "config.json"
@register_to_config
def __init__(
self,
a=2,
b=5,
c=(2, 5),
d="for diffusion",
e=[1, 5],
f=[5, 4],
):
pass
class ConfigTester(unittest.TestCase): class ConfigTester(unittest.TestCase):
def test_load_not_from_mixin(self): def test_load_not_from_mixin(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
...@@ -137,6 +153,7 @@ class ConfigTester(unittest.TestCase): ...@@ -137,6 +153,7 @@ class ConfigTester(unittest.TestCase):
assert config.pop("c") == (2, 5) # instantiated as tuple assert config.pop("c") == (2, 5) # instantiated as tuple
assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json
config.pop("_use_default_values")
assert config == new_config assert config == new_config
def test_load_ddim_from_pndm(self): def test_load_ddim_from_pndm(self):
...@@ -233,3 +250,39 @@ class ConfigTester(unittest.TestCase): ...@@ -233,3 +250,39 @@ class ConfigTester(unittest.TestCase):
assert dpm.__class__ == DPMSolverMultistepScheduler assert dpm.__class__ == DPMSolverMultistepScheduler
# no warning should be thrown # no warning should be thrown
assert cap_logger.out == "" assert cap_logger.out == ""
def test_use_default_values(self):
# let's first save a config that should be in the form
# a=2,
# b=5,
# c=(2, 5),
# d="for diffusion",
# e=[1, 3],
config = SampleObject()
config_dict = {k: v for k, v in config.config.items() if not k.startswith("_")}
# make sure that default config has all keys in `_use_default_values`
assert set(config_dict.keys()) == config.config._use_default_values
with tempfile.TemporaryDirectory() as tmpdirname:
config.save_config(tmpdirname)
# now loading it with SampleObject2 should put f into `_use_default_values`
config = SampleObject2.from_config(tmpdirname)
assert "f" in config._use_default_values
assert config.f == [1, 3]
# now loading the config, should **NOT** use [1, 3] for `f`, but the default [1, 4] value
# **BECAUSE** it is part of `config._use_default_values`
new_config = SampleObject4.from_config(config.config)
assert new_config.f == [5, 4]
config.config._use_default_values.pop()
new_config_2 = SampleObject4.from_config(config.config)
assert new_config_2.f == [1, 3]
# Nevertheless "e" should still be correctly loaded to [1, 3] from SampleObject2 instead of defaulting to [1, 5]
assert new_config_2.e == [1, 3]
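The behaviour exercised by this test is what makes scheduler swapping robust: parameters recorded in `_use_default_values` are dropped before re-instantiation, so the new class falls back to its own defaults for anything the saved config never set explicitly. A hedged usage sketch (the checkpoint id is only an example and must be available locally or downloadable):

```python
from diffusers import DiffusionPipeline, UniPCMultistepScheduler

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")  # example checkpoint
# Parameters the original scheduler left at their defaults (e.g. timestep_spacing)
# are not forced onto UniPC; it uses its own defaults for those.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
```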
...@@ -186,7 +186,7 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, Pipeli ...@@ -186,7 +186,7 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, Pipeli
assert image.shape == (1, 64, 64, 3) assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.4886, 0.5586, 0.4476, 0.5053, 0.6013, 0.4737, 0.5538, 0.5100, 0.4927]) expected_slice = np.array([0.4024, 0.6510, 0.4901, 0.5378, 0.5813, 0.5622, 0.4795, 0.4467, 0.4952])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
......
...@@ -101,7 +101,7 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest): ...@@ -101,7 +101,7 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
model = self.dummy_model() model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma sample = self.dummy_sample_deter * scheduler.init_noise_sigma.cpu()
sample = sample.to(torch_device) sample = sample.to(torch_device)
for t in scheduler.timesteps: for t in scheduler.timesteps:
...@@ -128,7 +128,7 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest): ...@@ -128,7 +128,7 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
model = self.dummy_model() model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma sample = self.dummy_sample_deter * scheduler.init_noise_sigma.cpu()
sample = sample.to(torch_device) sample = sample.to(torch_device)
for t in scheduler.timesteps: for t in scheduler.timesteps:
......
...@@ -47,7 +47,7 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest): ...@@ -47,7 +47,7 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
model = self.dummy_model() model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma sample = self.dummy_sample_deter * scheduler.init_noise_sigma.cpu()
sample = sample.to(torch_device) sample = sample.to(torch_device)
for i, t in enumerate(scheduler.timesteps): for i, t in enumerate(scheduler.timesteps):
...@@ -100,7 +100,7 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest): ...@@ -100,7 +100,7 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
model = self.dummy_model() model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma sample = self.dummy_sample_deter * scheduler.init_noise_sigma.cpu()
sample = sample.to(torch_device) sample = sample.to(torch_device)
for t in scheduler.timesteps: for t in scheduler.timesteps:
......
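The `.cpu()` calls added in the device tests above appear to account for `init_noise_sigma` now being derived from `self.sigmas`, which move to the accelerator after `set_timesteps(..., device=torch_device)`; moving the scale back to CPU keeps the multiplication with a CPU-resident sample valid. A minimal, self-contained sketch of the pattern (falls back to CPU when no accelerator is present):

```python
import torch
from diffusers import EulerDiscreteScheduler

torch_device = "cuda" if torch.cuda.is_available() else "cpu"  # stand-in for the test helper

scheduler = EulerDiscreteScheduler()
scheduler.set_timesteps(10, device=torch_device)                    # sigmas, and thus init_noise_sigma, follow the device
sample = torch.ones(1, 3, 8, 8) * scheduler.init_noise_sigma.cpu()  # keep the product on CPU, as in the tests
sample = sample.to(torch_device)
```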