Unverified commit b34be039 authored by Kashif Rasul and committed by GitHub

Karras VE, DDIM and DDPM flax schedulers (#508)

* beta never changes removed from state

* fix typos in docs

* removed unused var

* initial ddim flax scheduler

* import

* added dummy objects

* fix style

* fix typo

* docs

* fix typo in comment

* set return type

* added flax ddpm

* fix style

* remake

* pass PRNG key as argument and split before use

* fix doc string

* use config

* added flax Karras VE scheduler

* make style

* fix dummy

* fix ndarray type annotation

* replace returns a new state

* added lms_discrete scheduler

* use self.config

* add_noise needs state

* use config

* use config

* docstring

* added flax score sde ve

* fix imports

* fix typos
parent 83a7bb2a
@@ -504,7 +504,9 @@ def main():
     noise = torch.randn(latents.shape).to(latents.device)
     bsz = latents.shape[0]
     # Sample a random timestep for each image
-    timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long()
+    timesteps = torch.randint(
+        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+    ).long()
     # Add noise to the latents according to the noise magnitude at each timestep
     # (this is the forward diffusion process)
...
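The pattern in this hunk, reading `num_train_timesteps` through `noise_scheduler.config`, recurs throughout the commit: `register_to_config` records every `__init__` argument on the scheduler's config object instead of mirroring it as a plain attribute. A minimal sketch of that access pattern, assuming a standard `diffusers` install (the scheduler class and value are only illustrative):

import-wise this needs nothing beyond the library itself:

from diffusers import DDPMScheduler

# `register_to_config` stores all __init__ arguments, so hyperparameters
# are read back via `scheduler.config.<name>` rather than `scheduler.<name>`.
scheduler = DDPMScheduler(num_train_timesteps=1000)
assert scheduler.config.num_train_timesteps == 1000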
@@ -130,7 +130,7 @@ def main(args):
     bsz = clean_images.shape[0]
     # Sample a random timestep for each image
     timesteps = torch.randint(
-        0, noise_scheduler.num_train_timesteps, (bsz,), device=clean_images.device
+        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device
     ).long()
     # Add noise to the clean images according to the noise magnitude at each timestep
...
@@ -64,6 +64,13 @@ else:
 if is_flax_available():
     from .modeling_flax_utils import FlaxModelMixin
-    from .schedulers import FlaxPNDMScheduler
+    from .schedulers import (
+        FlaxDDIMScheduler,
+        FlaxDDPMScheduler,
+        FlaxKarrasVeScheduler,
+        FlaxLMSDiscreteScheduler,
+        FlaxPNDMScheduler,
+        FlaxScoreSdeVeScheduler,
+    )
 else:
     from .utils.dummy_flax_objects import *  # noqa F403
@@ -386,7 +386,7 @@ class FlaxModelMixin:
                 raise ValueError from e
             except (UnicodeDecodeError, ValueError):
                 raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
-        # make sure all arrays are stored as jnp.arrays
+        # make sure all arrays are stored as jnp.ndarray
         # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
         # https://github.com/google/flax/issues/1261
         state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
...
@@ -80,7 +80,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
             sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)
             # correction step
-            for _ in range(self.scheduler.correct_steps):
+            for _ in range(self.scheduler.config.correct_steps):
                 model_output = self.unet(sample, sigma_t).sample
                 sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample
...
@@ -28,7 +28,12 @@ else:
     from ..utils.dummy_pt_objects import *  # noqa F403
 if is_flax_available():
+    from .scheduling_ddim_flax import FlaxDDIMScheduler
+    from .scheduling_ddpm_flax import FlaxDDPMScheduler
+    from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
+    from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
     from .scheduling_pndm_flax import FlaxPNDMScheduler
+    from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
 else:
     from ..utils.dummy_flax_objects import *  # noqa F403
...
@@ -113,7 +113,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         # At every step in ddim, we are looking into the previous alphas_cumprod
         # For the final step, there is no previous alphas_cumprod because we are already at 0
-        # `set_alpha_to_one` decides whether we set this paratemer simply to one or
+        # `set_alpha_to_one` decides whether we set this parameter simply to one or
         # whether we use the final alpha of the "non-previous" one.
         self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
@@ -195,7 +195,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         # - pred_original_sample -> f_theta(x_t, t) or x_0
         # - std_dev_t -> sigma_t
         # - eta -> η
-        # - pred_sample_direction -> "direction pointingc to x_t"
+        # - pred_sample_direction -> "direction pointing to x_t"
         # - pred_prev_sample -> "x_t-1"
         # 1. get previous step value (=t-1)
...
# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import flax
import jax.numpy as jnp
from jax import random
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
Returns:
betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return jnp.array(betas, dtype=jnp.float32)
@flax.struct.dataclass
class DDIMSchedulerState:
# setable values
timesteps: jnp.ndarray
num_inference_steps: Optional[int] = None
@classmethod
def create(cls, num_train_timesteps: int):
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1])
@dataclass
class FlaxSchedulerOutput(SchedulerOutput):
state: DDIMSchedulerState
class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
"""
Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
diffusion probabilistic models (DDPMs) with non-Markovian guidance.
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more details, see the original paper: https://arxiv.org/abs/2010.02502
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`jnp.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
set_alpha_to_one (`bool`, default `True`):
if alpha for final step is 1 or the final alpha of the "non-previous" one.
"""
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
):
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
        elif beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
self.state = DDIMSchedulerState.create(num_train_timesteps=num_train_timesteps)
def _get_variance(self, timestep, prev_timestep):
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
return variance
def set_timesteps(
self, state: DDIMSchedulerState, num_inference_steps: int, offset: int = 0
) -> DDIMSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
state (`DDIMSchedulerState`):
the `FlaxDDIMScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
offset (`int`):
optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
"""
step_ratio = self.config.num_train_timesteps // num_inference_steps
# creates integer timesteps by multiplying by ratio
        # rounding to avoid issues when num_inference_step is power of 3
timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]
timesteps = timesteps + offset
return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps)
def step(
self,
state: DDIMSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
key: random.KeyArray,
eta: float = 0.0,
use_clipped_model_output: bool = False,
return_dict: bool = True,
) -> Union[FlaxSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance.
model_output (`jnp.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
key (`random.KeyArray`): a PRNG key.
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
"""
if state.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read the DDIM paper in detail to understand what follows
# Notation (<variable name> -> <name in paper>
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_sample -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_sample_direction -> "direction pointing to x_t"
# - pred_prev_sample -> "x_t-1"
# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
# 4. Clip "predicted x_0"
if self.config.clip_sample:
pred_original_sample = jnp.clip(pred_original_sample, -1, 1)
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
variance = self._get_variance(timestep, prev_timestep)
std_dev_t = eta * variance ** (0.5)
if use_clipped_model_output:
# the model_output is always re-derived from the clipped x_0 in Glide
model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if eta > 0:
            key = random.split(key, num=1)[0]  # random.split returns an array of keys; take the single new key
noise = random.normal(key=key, shape=model_output.shape)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
prev_sample = prev_sample + variance
if not return_dict:
return (prev_sample, state)
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
def add_noise(
self,
original_samples: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps
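Because `DDIMSchedulerState` is a `flax.struct.dataclass`, `set_timesteps` and `step` return fresh state rather than mutating the scheduler. A minimal sampling sketch under that assumption; the zero-valued `model_output` stands in for a real Flax UNet prediction:

import jax
import jax.numpy as jnp

scheduler = FlaxDDIMScheduler()
state = scheduler.set_timesteps(scheduler.state, num_inference_steps=50)

key = jax.random.PRNGKey(0)
key, sample_key = jax.random.split(key)
sample = jax.random.normal(sample_key, (1, 3, 32, 32))  # start from pure noise

for t in state.timesteps:
    model_output = jnp.zeros_like(sample)  # stand-in for a Flax UNet call
    key, step_key = jax.random.split(key)  # fresh subkey per step, split before use
    output = scheduler.step(state, model_output, int(t), sample, key=step_key)
    sample, state = output.prev_sample, output.state  # rebind both; state is immutable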
@@ -148,7 +148,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         if variance_type is None:
             variance_type = self.config.variance_type
-        # hacks - were probs added for training stability
+        # hacks - were probably added for training stability
         if variance_type == "fixed_small":
             variance = self.clip(variance, min_value=1e-20)
         # for rl-diffuser https://arxiv.org/abs/2205.09991
@@ -187,7 +187,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`torch.FloatTensor` or `np.ndarray`):
                 current instance of sample being created by diffusion process.
-            eta (`float`): weight of noise for added noise in diffusion step.
             predict_epsilon (`bool`):
                 optional flag to use when model predicts the samples directly instead of the noise, epsilon.
             generator: random number generator.
...
# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import flax
import jax.numpy as jnp
from jax import random
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
Returns:
betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return jnp.array(betas, dtype=jnp.float32)
@flax.struct.dataclass
class DDPMSchedulerState:
# setable values
timesteps: jnp.ndarray
num_inference_steps: Optional[int] = None
@classmethod
def create(cls, num_train_timesteps: int):
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1])
@dataclass
class FlaxSchedulerOutput(SchedulerOutput):
state: DDPMSchedulerState
class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
"""
Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
Langevin dynamics sampling.
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more details, see the original paper: https://arxiv.org/abs/2006.11239
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`jnp.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
variance_type (`str`):
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
"""
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None,
variance_type: str = "fixed_small",
clip_sample: bool = True,
):
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
elif beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
self.one = jnp.array(1.0)
self.state = DDPMSchedulerState.create(num_train_timesteps=num_train_timesteps)
self.variance_type = variance_type
def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int) -> DDPMSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
            state (`DDPMSchedulerState`):
the `FlaxDDPMScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
timesteps = jnp.arange(
0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps
)[::-1]
return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps)
def _get_variance(self, t, predicted_variance=None, variance_type=None):
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
# For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
# and sample from it to get previous sample
# x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
if variance_type is None:
variance_type = self.config.variance_type
# hacks - were probably added for training stability
if variance_type == "fixed_small":
variance = jnp.clip(variance, a_min=1e-20)
# for rl-diffuser https://arxiv.org/abs/2205.09991
elif variance_type == "fixed_small_log":
variance = jnp.log(jnp.clip(variance, a_min=1e-20))
elif variance_type == "fixed_large":
variance = self.betas[t]
elif variance_type == "fixed_large_log":
# Glide max_log
variance = jnp.log(self.betas[t])
elif variance_type == "learned":
return predicted_variance
elif variance_type == "learned_range":
min_log = variance
max_log = self.betas[t]
frac = (predicted_variance + 1) / 2
variance = frac * max_log + (1 - frac) * min_log
return variance
def step(
self,
state: DDPMSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
key: random.KeyArray,
predict_epsilon: bool = True,
return_dict: bool = True,
) -> Union[FlaxSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
state (`DDPMSchedulerState`): the `FlaxDDPMScheduler` state data class instance.
model_output (`jnp.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
key (`random.KeyArray`): a PRNG key.
predict_epsilon (`bool`):
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
"""
t = timestep
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
            # split the channels in half: first half is the mean prediction, second half the variance
            model_output, predicted_variance = jnp.split(model_output, 2, axis=1)
else:
predicted_variance = None
# 1. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if predict_epsilon:
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
else:
pred_original_sample = model_output
# 3. Clip "predicted x_0"
if self.config.clip_sample:
pred_original_sample = jnp.clip(pred_original_sample, -1, 1)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
# 5. Compute predicted previous sample µ_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
# 6. Add noise
variance = 0
if t > 0:
            key = random.split(key, num=1)[0]  # random.split returns an array of keys; take the single new key
noise = random.normal(key=key, shape=model_output.shape)
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
pred_prev_sample = pred_prev_sample + variance
if not return_dict:
return (pred_prev_sample, state)
return FlaxSchedulerOutput(prev_sample=pred_prev_sample, state=state)
def add_noise(
self,
original_samples: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps
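For training, `add_noise` is the closed-form forward process noisy = sqrt(alpha_bar_t) * clean + sqrt(1 - alpha_bar_t) * noise. A hedged sketch mirroring the training-loop hunks above; batch shapes are illustrative and it assumes the `match_shape` broadcast helper from this version of `SchedulerMixin`:

import jax
import jax.numpy as jnp

scheduler = FlaxDDPMScheduler(num_train_timesteps=1000)

key = jax.random.PRNGKey(0)
key, noise_key, t_key = jax.random.split(key, 3)

clean_images = jnp.zeros((8, 3, 32, 32))  # placeholder batch
noise = jax.random.normal(noise_key, clean_images.shape)
# sample one random timestep per image, as in the patched training loops
timesteps = jax.random.randint(t_key, (8,), 0, scheduler.config.num_train_timesteps)

noisy_images = scheduler.add_noise(clean_images, noise, timesteps)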
@@ -105,7 +105,10 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
         self.num_inference_steps = num_inference_steps
         self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
         self.schedule = [
-            (self.sigma_max * (self.sigma_min**2 / self.sigma_max**2) ** (i / (num_inference_steps - 1)))
+            (
+                self.config.sigma_max
+                * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
+            )
             for i in self.timesteps
         ]
         self.schedule = np.array(self.schedule, dtype=np.float32)
@@ -121,13 +124,13 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
         TODO Args:
         """
-        if self.s_min <= sigma <= self.s_max:
-            gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1)
+        if self.config.s_min <= sigma <= self.config.s_max:
+            gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1)
         else:
             gamma = 0
         # sample eps ~ N(0, S_noise^2 * I)
-        eps = self.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device)
+        eps = self.config.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device)
         sigma_hat = sigma + gamma * sigma
         sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
...
# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import flax
import jax.numpy as jnp
from jax import random
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@flax.struct.dataclass
class KarrasVeSchedulerState:
# setable values
num_inference_steps: Optional[int] = None
timesteps: Optional[jnp.ndarray] = None
schedule: Optional[jnp.ndarray] = None # sigma(t_i)
@classmethod
def create(cls):
return cls()
@dataclass
class FlaxKarrasVeOutput(BaseOutput):
"""
Output class for the scheduler's step function output.
Args:
prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
derivative (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
            Derivative of predicted original image sample (x_0).
state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
"""
prev_sample: jnp.ndarray
derivative: jnp.ndarray
state: KarrasVeSchedulerState
class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
"""
Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
the VE column of Table 1 from [1] for reference.
[1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
differential equations." https://arxiv.org/abs/2011.13456
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
Args:
sigma_min (`float`): minimum noise magnitude
sigma_max (`float`): maximum noise magnitude
s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
A reasonable range is [1.000, 1.011].
s_churn (`float`): the parameter controlling the overall amount of stochasticity.
A reasonable range is [0, 100].
s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
A reasonable range is [0, 10].
s_max (`float`): the end value of the sigma range where we add noise.
A reasonable range is [0.2, 80].
"""
@register_to_config
def __init__(
self,
sigma_min: float = 0.02,
sigma_max: float = 100,
s_noise: float = 1.007,
s_churn: float = 80,
s_min: float = 0.05,
s_max: float = 50,
):
self.state = KarrasVeSchedulerState.create()
def set_timesteps(self, state: KarrasVeSchedulerState, num_inference_steps: int) -> KarrasVeSchedulerState:
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
state (`KarrasVeSchedulerState`):
the `FlaxKarrasVeScheduler` state data class.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
timesteps = jnp.arange(0, num_inference_steps)[::-1].copy()
schedule = [
(
self.config.sigma_max
* (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
)
for i in timesteps
]
return state.replace(
num_inference_steps=num_inference_steps,
schedule=jnp.array(schedule, dtype=jnp.float32),
timesteps=timesteps,
)
def add_noise_to_input(
self,
state: KarrasVeSchedulerState,
sample: jnp.ndarray,
sigma: float,
key: random.KeyArray,
) -> Tuple[jnp.ndarray, float]:
"""
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
TODO Args:
"""
if self.config.s_min <= sigma <= self.config.s_max:
gamma = min(self.config.s_churn / state.num_inference_steps, 2**0.5 - 1)
else:
gamma = 0
# sample eps ~ N(0, S_noise^2 * I)
        key = random.split(key, num=1)[0]  # random.split returns an array of keys; take the single new key
eps = self.config.s_noise * random.normal(key=key, shape=sample.shape)
sigma_hat = sigma + gamma * sigma
sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
return sample_hat, sigma_hat
def step(
self,
state: KarrasVeSchedulerState,
model_output: jnp.ndarray,
sigma_hat: float,
sigma_prev: float,
sample_hat: jnp.ndarray,
return_dict: bool = True,
) -> Union[FlaxKarrasVeOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
sigma_hat (`float`): TODO
sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion
chain and derivative. [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] if `return_dict` is
True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
pred_original_sample = sample_hat + sigma_hat * model_output
derivative = (sample_hat - pred_original_sample) / sigma_hat
sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
if not return_dict:
return (sample_prev, derivative, state)
return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state)
def step_correct(
self,
state: KarrasVeSchedulerState,
model_output: jnp.ndarray,
sigma_hat: float,
sigma_prev: float,
sample_hat: jnp.ndarray,
sample_prev: jnp.ndarray,
derivative: jnp.ndarray,
return_dict: bool = True,
) -> Union[FlaxKarrasVeOutput, Tuple]:
"""
Correct the predicted sample based on the output model_output of the network. TODO complete description
Args:
state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
sigma_hat (`float`): TODO
sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
derivative (`torch.FloatTensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
"""
pred_original_sample = sample_prev + sigma_prev * model_output
derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
if not return_dict:
return (sample_prev, derivative, state)
return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state)
def add_noise(self, original_samples, noise, timesteps):
raise NotImplementedError()
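Taken together, Algorithm 2 alternates the churn step with an Euler step (and, in the full algorithm, a second-order correction via `step_correct`). A rough sketch of the sampling loop, with a zero-valued `denoise` standing in for a trained model:

import jax
import jax.numpy as jnp

scheduler = FlaxKarrasVeScheduler()
state = scheduler.set_timesteps(scheduler.state, num_inference_steps=50)

key = jax.random.PRNGKey(0)
key, sample_key = jax.random.split(key)
sample = jax.random.normal(sample_key, (1, 3, 32, 32)) * scheduler.config.sigma_max

for t in state.timesteps:
    sigma = state.schedule[t]  # current noise level, largest first
    sigma_prev = state.schedule[t - 1] if t > 0 else 0.0
    key, churn_key = jax.random.split(key)
    # 1. churn: raise the noise level from sigma to sigma_hat
    sample_hat, sigma_hat = scheduler.add_noise_to_input(state, sample, sigma, key=churn_key)
    # 2. Euler step from sigma_hat down to sigma_prev
    model_output = jnp.zeros_like(sample_hat)  # stand-in for the trained model
    output = scheduler.step(state, model_output, sigma_hat, sigma_prev, sample_hat)
    sample, state = output.prev_sample, output.state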
@@ -113,7 +113,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
                 the number of diffusion steps used when generating samples with a pre-trained model.
         """
         self.num_inference_steps = num_inference_steps
-        self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
+        self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
         low_idx = np.floor(self.timesteps).astype(int)
         high_idx = np.ceil(self.timesteps).astype(int)
...
# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import flax
import jax.numpy as jnp
from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
@flax.struct.dataclass
class LMSDiscreteSchedulerState:
# setable values
num_inference_steps: Optional[int] = None
timesteps: Optional[jnp.ndarray] = None
sigmas: Optional[jnp.ndarray] = None
derivatives: jnp.ndarray = jnp.array([])
@classmethod
def create(cls, num_train_timesteps: int, sigmas: jnp.ndarray):
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], sigmas=sigmas)
@dataclass
class FlaxSchedulerOutput(SchedulerOutput):
state: LMSDiscreteSchedulerState
class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
Katherine Crowson:
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
trained_betas (`jnp.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
"""
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None,
):
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
        elif beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
self.state = LMSDiscreteSchedulerState.create(
num_train_timesteps=num_train_timesteps, sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
)
def get_lms_coefficient(self, state, order, t, current_order):
"""
Compute a linear multistep coefficient.
Args:
order (TODO):
t (TODO):
current_order (TODO):
"""
def lms_derivative(tau):
prod = 1.0
for k in range(order):
if current_order == k:
continue
prod *= (tau - state.sigmas[t - k]) / (state.sigmas[t - current_order] - state.sigmas[t - k])
return prod
integrated_coeff = integrate.quad(lms_derivative, state.sigmas[t], state.sigmas[t + 1], epsrel=1e-4)[0]
return integrated_coeff
def set_timesteps(self, state: LMSDiscreteSchedulerState, num_inference_steps: int) -> LMSDiscreteSchedulerState:
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
state (`LMSDiscreteSchedulerState`):
the `FlaxLMSDiscreteScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=jnp.float32)
low_idx = jnp.floor(timesteps).astype(int)
high_idx = jnp.ceil(timesteps).astype(int)
frac = jnp.mod(timesteps, 1.0)
sigmas = jnp.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
sigmas = jnp.concatenate([sigmas, jnp.array([0.0])]).astype(jnp.float32)
return state.replace(
num_inference_steps=num_inference_steps,
timesteps=timesteps,
derivatives=jnp.array([]),
sigmas=sigmas,
)
def step(
self,
state: LMSDiscreteSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
order: int = 4,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
state (`LMSDiscreteSchedulerState`): the `FlaxLMSDiscreteScheduler` state data class instance.
model_output (`jnp.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
order: coefficient for multi-step inference.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
"""
sigma = state.sigmas[timestep]
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
pred_original_sample = sample - sigma * model_output
# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma
        # jnp arrays are immutable and have no append()/pop(); use the functional equivalents
        state = state.replace(derivatives=jnp.append(state.derivatives, derivative))
        if len(state.derivatives) > order:
            state = state.replace(derivatives=jnp.delete(state.derivatives, 0))
# 3. Compute linear multistep coefficients
order = min(timestep + 1, order)
lms_coeffs = [self.get_lms_coefficient(state, order, timestep, curr_order) for curr_order in range(order)]
# 4. Compute previous sample based on the derivatives path
prev_sample = sample + sum(
coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(state.derivatives))
)
if not return_dict:
return (prev_sample, state)
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
def add_noise(
self,
state: LMSDiscreteSchedulerState,
original_samples: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sigmas = self.match_shape(state.sigmas[timesteps], noise)
noisy_samples = original_samples + noise * sigmas
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps
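`get_lms_coefficient` integrates the Lagrange basis polynomial of the `order` most recent sigma points from `sigmas[t]` to `sigmas[t + 1]`, the classic Adams-Bashforth construction. A tiny self-contained check of that idea, using toy sigma values and no scheduler class:

import numpy as np
from scipy import integrate

sigmas = np.array([10.0, 7.0, 5.0, 3.0])  # toy descending sigma schedule
order, t = 2, 2  # use the two most recent points to step from sigmas[2] to sigmas[3]

def lms_derivative(tau, current_order):
    # Lagrange basis polynomial for the `current_order`-th most recent derivative
    prod = 1.0
    for k in range(order):
        if current_order == k:
            continue
        prod *= (tau - sigmas[t - k]) / (sigmas[t - current_order] - sigmas[t - k])
    return prod

coeffs = [
    integrate.quad(lms_derivative, sigmas[t], sigmas[t + 1], args=(co,), epsrel=1e-4)[0]
    for co in range(order)
]
# The basis polynomials sum to one everywhere, so the coefficients sum to the step size.
assert np.isclose(sum(coeffs), sigmas[t + 1] - sigmas[t])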
@@ -108,8 +108,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
-        self.one = np.array(1.0)
-
         # For now we only support F-PNDM, i.e. the runge-kutta method
         # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
         # mainly at formula (9), (12), (13) and the Algorithm 2.
...
@@ -25,7 +25,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
 from .scheduling_utils import SchedulerMixin, SchedulerOutput
-def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999):
+def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
     """
     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
     (1-beta) over time from t = [0,1].
@@ -40,7 +40,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999):
             prevent singularities.
     Returns:
-        betas (`jnp.array`): the betas used by the scheduler to step the model outputs
+        betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
     """
     def alpha_bar(time_step):
@@ -56,36 +56,23 @@ def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999):
 @flax.struct.dataclass
 class PNDMSchedulerState:
-    betas: jnp.array
     # setable values
-    _timesteps: jnp.array
+    _timesteps: jnp.ndarray
     num_inference_steps: Optional[int] = None
     _offset: int = 0
-    prk_timesteps: Optional[jnp.array] = None
-    plms_timesteps: Optional[jnp.array] = None
-    timesteps: Optional[jnp.array] = None
+    prk_timesteps: Optional[jnp.ndarray] = None
+    plms_timesteps: Optional[jnp.ndarray] = None
+    timesteps: Optional[jnp.ndarray] = None
     # running values
     cur_model_output: Optional[jnp.ndarray] = None
     counter: int = 0
     cur_sample: Optional[jnp.ndarray] = None
-    ets: jnp.array = jnp.array([])
+    ets: jnp.ndarray = jnp.array([])
-    @property
-    def alphas(self) -> jnp.array:
-        return 1.0 - self.betas
-    @property
-    def alphas_cumprod(self) -> jnp.array:
-        return jnp.cumprod(self.alphas, axis=0)
     @classmethod
-    def create(cls, betas: jnp.array, num_train_timesteps: int):
-        return cls(
-            betas=betas,
-            _timesteps=jnp.arange(0, num_train_timesteps)[::-1],
-        )
+    def create(cls, num_train_timesteps: int):
+        return cls(_timesteps=jnp.arange(0, num_train_timesteps)[::-1])
 @dataclass
@@ -112,7 +99,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         beta_schedule (`str`):
             the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
             `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
-        trained_betas (`np.ndarray`, optional):
+        trained_betas (`jnp.ndarray`, optional):
             option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
         skip_prk_steps (`bool`):
             allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
@@ -126,28 +113,31 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[jnp.array] = None,
+        trained_betas: Optional[jnp.ndarray] = None,
         skip_prk_steps: bool = False,
     ):
         if trained_betas is not None:
-            betas = jnp.asarray(trained_betas)
+            self.betas = jnp.asarray(trained_betas)
         if beta_schedule == "linear":
-            betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
+            self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
         elif beta_schedule == "scaled_linear":
             # this schedule is very specific to the latent diffusion model.
-            betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
+            self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
         elif beta_schedule == "squaredcos_cap_v2":
             # Glide cosine schedule
-            betas = betas_for_alpha_bar(num_train_timesteps)
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
         # For now we only support F-PNDM, i.e. the runge-kutta method
         # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
         # mainly at formula (9), (12), (13) and the Algorithm 2.
         self.pndm_order = 4
-        self.state = PNDMSchedulerState.create(betas=betas, num_train_timesteps=num_train_timesteps)
+        self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps)
     def set_timesteps(
         self, state: PNDMSchedulerState, num_inference_steps: int, offset: int = 0
@@ -157,7 +147,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         Args:
             state (`PNDMSchedulerState`):
-                the PNDMScheduler state data class instance.
+                the `FlaxPNDMScheduler` state data class instance.
             num_inference_steps (`int`):
                 the number of diffusion steps used when generating samples with a pre-trained model.
             offset (`int`):
@@ -165,7 +155,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         """
         step_ratio = self.config.num_train_timesteps // num_inference_steps
         # creates integer timesteps by multiplying by ratio
-        # casting to int to avoid issues when num_inference_step is power of 3
+        # rounding to avoid issues when num_inference_step is power of 3
         _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]
         _timesteps = _timesteps + offset
@@ -212,7 +202,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.
         Args:
-            state (`PNDMSchedulerState`): the PNDMScheduler state data class instance.
+            state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
             model_output (`jnp.ndarray`): direct output from learned diffusion model.
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`jnp.ndarray`):
@@ -246,7 +236,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         solution to the differential equation.
         Args:
-            state (`PNDMSchedulerState`): the PNDMScheduler state data class instance.
+            state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
             model_output (`jnp.ndarray`): direct output from learned diffusion model.
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`jnp.ndarray`):
...@@ -268,24 +258,24 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -268,24 +258,24 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
timestep = state.prk_timesteps[state.counter // 4 * 4] timestep = state.prk_timesteps[state.counter // 4 * 4]
if state.counter % 4 == 0: if state.counter % 4 == 0:
state.replace( state = state.replace(
cur_model_output=state.cur_model_output + 1 / 6 * model_output, cur_model_output=state.cur_model_output + 1 / 6 * model_output,
ets=state.ets.append(model_output), ets=state.ets.append(model_output),
cur_sample=sample, cur_sample=sample,
) )
elif (self.counter - 1) % 4 == 0: elif (self.counter - 1) % 4 == 0:
state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output) state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
elif (self.counter - 2) % 4 == 0: elif (self.counter - 2) % 4 == 0:
state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output) state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
elif (self.counter - 3) % 4 == 0: elif (self.counter - 3) % 4 == 0:
model_output = state.cur_model_output + 1 / 6 * model_output model_output = state.cur_model_output + 1 / 6 * model_output
state.replace(cur_model_output=0) state = state.replace(cur_model_output=0)
# cur_sample should not be `None` # cur_sample should not be `None`
cur_sample = state.cur_sample if state.cur_sample is not None else sample cur_sample = state.cur_sample if state.cur_sample is not None else sample
prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output, state=state) prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output, state=state)
state.replace(counter=state.counter + 1) state = state.replace(counter=state.counter + 1)
if not return_dict: if not return_dict:
return (prev_sample, state) return (prev_sample, state)
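The recurring fix in this hunk, `state.replace(...)` becoming `state = state.replace(...)`, matters because `flax.struct.dataclass` instances are immutable: `replace` returns a new state rather than mutating in place. A minimal sketch with a hypothetical `Counter` state:

import flax

@flax.struct.dataclass
class Counter:
    count: int = 0

state = Counter()
state.replace(count=1)          # return value discarded: state.count is still 0
state = state.replace(count=1)  # rebound to the new instance: state.count is 1
assert state.count == 1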
...@@ -305,7 +295,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -305,7 +295,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
times to approximate the solution. times to approximate the solution.
Args: Args:
state (`PNDMSchedulerState`): the PNDMScheduler state data class instance. state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
model_output (`jnp.ndarray`): direct output from learned diffusion model. model_output (`jnp.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain. timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`): sample (`jnp.ndarray`):
...@@ -333,18 +323,18 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -333,18 +323,18 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
prev_timestep = max(timestep - self.config.num_train_timesteps // state.num_inference_steps, 0) prev_timestep = max(timestep - self.config.num_train_timesteps // state.num_inference_steps, 0)
if state.counter != 1: if state.counter != 1:
state.replace(ets=state.ets.append(model_output)) state = state.replace(ets=state.ets.append(model_output))
else: else:
prev_timestep = timestep prev_timestep = timestep
timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps
if len(state.ets) == 1 and state.counter == 0: if len(state.ets) == 1 and state.counter == 0:
model_output = model_output model_output = model_output
state.replace(cur_sample=sample) state = state.replace(cur_sample=sample)
elif len(state.ets) == 1 and state.counter == 1: elif len(state.ets) == 1 and state.counter == 1:
model_output = (model_output + state.ets[-1]) / 2 model_output = (model_output + state.ets[-1]) / 2
sample = state.cur_sample sample = state.cur_sample
state.replace(cur_sample=None) state = state.replace(cur_sample=None)
elif len(state.ets) == 2: elif len(state.ets) == 2:
model_output = (3 * state.ets[-1] - state.ets[-2]) / 2 model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
elif len(state.ets) == 3: elif len(state.ets) == 3:
...@@ -355,7 +345,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -355,7 +345,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
) )
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output, state=state) prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output, state=state)
state.replace(counter=state.counter + 1) state = state.replace(counter=state.counter + 1)
if not return_dict: if not return_dict:
return (prev_sample, state) return (prev_sample, state)
...@@ -375,8 +365,8 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -375,8 +365,8 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
# sample -> x_t # sample -> x_t
# model_output -> e_θ(x_t, t) # model_output -> e_θ(x_t, t)
# prev_sample -> x_(t−δ) # prev_sample -> x_(t−δ)
alpha_prod_t = state.alphas_cumprod[timestep + 1 - state._offset] alpha_prod_t = self.alphas_cumprod[timestep + 1 - state._offset]
alpha_prod_t_prev = state.alphas_cumprod[timestep_prev + 1 - state._offset] alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - state._offset]
beta_prod_t = 1 - alpha_prod_t beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev beta_prod_t_prev = 1 - alpha_prod_t_prev
...@@ -400,14 +390,13 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -400,14 +390,13 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
def add_noise( def add_noise(
self, self,
state: PNDMSchedulerState,
original_samples: jnp.ndarray, original_samples: jnp.ndarray,
noise: jnp.ndarray, noise: jnp.ndarray,
timesteps: jnp.ndarray, timesteps: jnp.ndarray,
) -> jnp.ndarray: ) -> jnp.ndarray:
sqrt_alpha_prod = state.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - state.alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
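The `add_noise` hunk implements the usual forward-diffusion rule x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise. A self-contained sketch under an assumed linear beta schedule, broadcasting by hand in place of the scheduler's `match_shape` helper:

import jax.numpy as jnp
from jax import random

betas = jnp.linspace(1e-4, 2e-2, 1000)     # assumed schedule, for illustration
alphas_cumprod = jnp.cumprod(1.0 - betas)  # alpha_bar_t

key0, key1 = random.split(random.PRNGKey(0))
original_samples = random.normal(key0, (4, 3, 8, 8))  # stand-in clean samples
noise = random.normal(key1, (4, 3, 8, 8))
timesteps = jnp.array([0, 250, 500, 999])             # one timestep per sample

sqrt_alpha_prod = alphas_cumprod[timesteps][:, None, None, None] ** 0.5
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps])[:, None, None, None] ** 0.5
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise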
......
...@@ -55,6 +55,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -55,6 +55,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin.from_config`] functions. [`~ConfigMixin.from_config`] functions.
Args: Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
snr (`float`): snr (`float`):
coefficient weighting the step from the model_output sample (from the network) to the random noise. coefficient weighting the step from the model_output sample (from the network) to the random noise.
sigma_min (`float`): sigma_min (`float`):
......
# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import flax
import jax.numpy as jnp
from jax import random
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
@flax.struct.dataclass
class ScoreSdeVeSchedulerState:
# setable values
timesteps: Optional[jnp.ndarray] = None
discrete_sigmas: Optional[jnp.ndarray] = None
sigmas: Optional[jnp.ndarray] = None
@classmethod
def create(cls):
return cls()
@dataclass
class FlaxSdeVeOutput(SchedulerOutput):
"""
Output class for the ScoreSdeVeScheduler's step function output.
Args:
state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
prev_sample_mean (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
Mean averaged `prev_sample`: the same update as `prev_sample`, but without the added noise term; typically used as the final denoised output.
"""
state: ScoreSdeVeSchedulerState
prev_sample: jnp.ndarray
prev_sample_mean: Optional[jnp.ndarray] = None
class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
"""
The variance exploding stochastic differential equation (SDE) scheduler.
For more information, see the original paper: https://arxiv.org/abs/2011.13456
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
snr (`float`):
coefficient weighting the step from the model_output sample (from the network) to the random noise.
sigma_min (`float`):
initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
distribution of the data.
sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model.
sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
epsilon.
correct_steps (`int`): number of correction steps performed on a produced sample.
"""
@register_to_config
def __init__(
self,
num_train_timesteps: int = 2000,
snr: float = 0.15,
sigma_min: float = 0.01,
sigma_max: float = 1348.0,
sampling_eps: float = 1e-5,
correct_steps: int = 1,
):
state = ScoreSdeVeSchedulerState.create()
self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps)
def set_timesteps(
self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, sampling_eps: float = None
) -> ScoreSdeVeSchedulerState:
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
"""
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
timesteps = jnp.linspace(1, sampling_eps, num_inference_steps)
return state.replace(timesteps=timesteps)
def set_sigmas(
self,
state: ScoreSdeVeSchedulerState,
num_inference_steps: int,
sigma_min: float = None,
sigma_max: float = None,
sampling_eps: float = None,
) -> ScoreSdeVeSchedulerState:
"""
Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.
The sigmas control the weight of the `drift` and `diffusion` components of sample update.
Args:
state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
sigma_min (`float`, optional):
initial noise scale value (overrides value given at Scheduler instantiation).
sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation).
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
"""
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
if state.timesteps is None:
state = self.set_timesteps(state, num_inference_steps, sampling_eps)
discrete_sigmas = jnp.exp(jnp.linspace(jnp.log(sigma_min), jnp.log(sigma_max), num_inference_steps))
sigmas = jnp.array([sigma_min * (sigma_max / sigma_min) ** t for t in state.timesteps])
return state.replace(discrete_sigmas=discrete_sigmas, sigmas=sigmas)
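# Usage sketch (illustrative step count): run `set_timesteps` before `set_sigmas`
# so the per-timestep sigmas are computed over the inference grid.
scheduler = FlaxScoreSdeVeScheduler()
state = scheduler.set_timesteps(scheduler.state, num_inference_steps=50)
state = scheduler.set_sigmas(state, num_inference_steps=50)
# timesteps descend linearly from 1 to sampling_eps; sigmas follow the
# geometric interpolation sigma_min * (sigma_max / sigma_min) ** t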
def get_adjacent_sigma(self, state, timesteps, t):
return jnp.where(timesteps == 0, jnp.zeros_like(t), state.discrete_sigmas[timesteps - 1])
def step_pred(
self,
state: ScoreSdeVeSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
key: random.KeyArray,
return_dict: bool = True,
) -> Union[FlaxSdeVeOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
model_output (`jnp.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
key (`random.KeyArray`): a PRNG key used to sample the noise.
return_dict (`bool`): option for returning a tuple rather than a `FlaxSdeVeOutput` class
Returns:
[`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
if state.timesteps is None:
raise ValueError(
"`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
)
timestep = timestep * jnp.ones(
sample.shape[0],
)
timesteps = (timestep * (len(state.timesteps) - 1)).astype(jnp.int32)  # integer cast for indexing; jnp arrays have no torch-style .long()
sigma = state.discrete_sigmas[timesteps]
adjacent_sigma = self.get_adjacent_sigma(state, timesteps, timestep)
drift = jnp.zeros_like(sample)
diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
# equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
# also equation 47 shows the analog from SDE models to ancestral sampling methods
drift = drift - diffusion[:, None, None, None] ** 2 * model_output
# equation 6: sample noise for the diffusion term of the SDE
key = random.split(key, num=1)[0]  # take the single new key; `split` returns an array of keys
noise = random.normal(key=key, shape=sample.shape)
prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
# TODO is the variable diffusion the correct scaling term for the noise?
prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g
if not return_dict:
return (prev_sample, prev_sample_mean, state)
return FlaxSdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean, state=state)
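# Illustrative sketch (assumed helper, not part of this file): stripped of
# state handling, `step_pred` applies the discretized reverse VE SDE,
# equations (6)/(47) of the paper.
def reverse_ve_update(sample, score, sigma_t, sigma_prev, noise):
    g2 = sigma_t**2 - sigma_prev**2         # squared diffusion coefficient
    mean = sample + g2 * score              # score is grad_x log p_t(x)
    return mean + (g2**0.5) * noise, mean   # (prev_sample, prev_sample_mean)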
def step_correct(
self,
state: ScoreSdeVeSchedulerState,
model_output: jnp.ndarray,
sample: jnp.ndarray,
key: random.KeyArray,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
after making the prediction for the previous timestep.
Args:
state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
model_output (`jnp.ndarray`): direct output from learned diffusion model.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
key (`random.KeyArray`): a PRNG key used to sample the noise.
return_dict (`bool`): option for returning a tuple rather than a `FlaxSdeVeOutput` class
Returns:
[`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
if state.timesteps is None:
raise ValueError(
"`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
)
# For small batch sizes, the paper suggests "replacing norm(z) with sqrt(d), where d is the dim. of z"
# sample noise for correction
key = random.split(key, num=1)[0]  # take the single new key; `split` returns an array of keys
noise = random.normal(key=key, shape=sample.shape)
# compute step size from the model_output, the noise, and the snr
grad_norm = jnp.linalg.norm(model_output)
noise_norm = jnp.linalg.norm(noise)
step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
step_size = step_size * jnp.ones(sample.shape[0])
# compute corrected sample: model_output term and noise term
prev_sample_mean = sample + step_size[:, None, None, None] * model_output
prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
if not return_dict:
return (prev_sample, state)
return FlaxSdeVeOutput(prev_sample=prev_sample, state=state)
def __len__(self):
return self.config.num_train_timesteps
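Putting the two methods together, a hedged sketch of the predictor-corrector sampling loop they are designed for. `score_model` is a placeholder for a trained score network, the shapes are illustrative, and `state` is assumed to have been prepared with `set_timesteps`/`set_sigmas` as above:

import jax.numpy as jnp
from jax import random

key, init_key = random.split(random.PRNGKey(0))
sample = random.normal(init_key, (1, 3, 32, 32)) * scheduler.config.sigma_max

for i, t in enumerate(state.timesteps):
    sigma_t = state.sigmas[i] * jnp.ones(sample.shape[0])

    # corrector: Langevin-style refinement, `correct_steps` times per timestep
    for _ in range(scheduler.config.correct_steps):
        model_output = score_model(sample, sigma_t)  # placeholder call
        key, step_key = random.split(key)
        sample = scheduler.step_correct(state, model_output, sample, key=step_key).prev_sample

    # predictor: one reverse-SDE step toward the previous timestep
    model_output = score_model(sample, sigma_t)      # placeholder call
    key, step_key = random.split(key)
    output = scheduler.step_pred(state, model_output, t, sample, key=step_key)
    sample, sample_mean = output.prev_sample, output.prev_sample_mean
# the final `sample_mean` is typically taken as the denoised output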
...@@ -11,8 +11,43 @@ class FlaxModelMixin(metaclass=DummyObject): ...@@ -11,8 +11,43 @@ class FlaxModelMixin(metaclass=DummyObject):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
class FlaxDDIMScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxDDPMScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxKarrasVeScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxLMSDiscreteScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxPNDMScheduler(metaclass=DummyObject): class FlaxPNDMScheduler(metaclass=DummyObject):
_backends = ["flax"] _backends = ["flax"]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
class FlaxScoreSdeVeScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
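For context on the dummy classes added above: they let `from diffusers import FlaxDDPMScheduler` succeed even when Flax is not installed, deferring the failure to first use. A condensed sketch of the pattern (these are assumed, simplified stand-ins for `DummyObject` and `requires_backends` from `diffusers.utils`):

class DummyObject(type):
    """Metaclass stub so class-level access can also be intercepted."""

def requires_backends(obj, backends):
    raise ImportError(f"{type(obj).__name__} requires the {backends} backend(s) to be installed.")

class FlaxDDPMScheduler(metaclass=DummyObject):
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

try:
    FlaxDDPMScheduler()  # importable, but unusable without the backend
except ImportError as err:
    print(err)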
...@@ -814,7 +814,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): ...@@ -814,7 +814,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
for i, t in enumerate(scheduler.timesteps): for i, t in enumerate(scheduler.timesteps):
sigma_t = scheduler.sigmas[i] sigma_t = scheduler.sigmas[i]
for _ in range(scheduler.correct_steps): for _ in range(scheduler.config.correct_steps):
with torch.no_grad(): with torch.no_grad():
model_output = model(sample, sigma_t) model_output = model(sample, sigma_t)
sample = scheduler.step_correct(model_output, sample, generator=generator, **kwargs).prev_sample sample = scheduler.step_correct(model_output, sample, generator=generator, **kwargs).prev_sample
......
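The one-line test change above follows from the config refactor: arguments registered via `@register_to_config` live on the scheduler's frozen `config` rather than as plain instance attributes. A sketch of the access pattern (the argument value is illustrative):

scheduler = FlaxScoreSdeVeScheduler(correct_steps=2)
assert scheduler.config.correct_steps == 2  # registered init args live on .config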