Unverified commit 240abddf authored by Kashif Rasul, committed by GitHub

[Flax] added broadcast_to_shape_from_left helper and Scheduler tests (#864)



* added broadcast_to_shape_from_left helper

* initial tests

* fixed pndm tests

* shape required for pndm

* added require_flax

* fix style

* fix more imports
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 38ae5a25
src/diffusers/schedulers/__init__.py
@@ -34,7 +34,7 @@ if is_flax_available():
     from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
     from .scheduling_pndm_flax import FlaxPNDMScheduler
     from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
-    from .scheduling_utils_flax import FlaxSchedulerMixin
+    from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
 else:
     from ..utils.dummy_flax_objects import *  # noqa F403
src/diffusers/schedulers/scheduling_ddim_flax.py
@@ -23,7 +23,7 @@
 import flax
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -173,7 +173,9 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
         return variance

-    def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDIMSchedulerState:
+    def set_timesteps(
+        self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = ()
+    ) -> DDIMSchedulerState:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -211,9 +213,6 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`jnp.ndarray`):
                 current instance of sample being created by diffusion process.
-            key (`random.KeyArray`): a PRNG key.
-            eta (`float`): weight of noise for added noise in diffusion step.
-            use_clipped_model_output (`bool`): TODO
             return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class

         Returns:
@@ -279,13 +278,11 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_alpha_prod = sqrt_alpha_prod[:, None]
+        sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[:, None]
+        sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
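The helper is behavior-compatible with the while-loops it replaces: it appends trailing singleton axes in one reshape instead of growing the array one axis at a time. A minimal sanity check (a sketch, not part of the commit, assuming JAX and this branch of diffusers are installed):

import jax.numpy as jnp
from diffusers.schedulers.scheduling_utils_flax import broadcast_to_shape_from_left

# Per-sample coefficients for a batch of 4 samples of shape (3, 8, 8).
coeffs = jnp.array([0.1, 0.2, 0.3, 0.4])
target_shape = (4, 3, 8, 8)

# Old pattern: add trailing axes one at a time, rely on implicit broadcasting.
old = coeffs
while len(old.shape) < len(target_shape):
    old = old[:, None]

# New pattern: reshape once, then broadcast explicitly.
new = broadcast_to_shape_from_left(coeffs, target_shape)

assert (old * jnp.ones(target_shape) == new).all()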
src/diffusers/schedulers/scheduling_ddpm_flax.py
@@ -23,7 +23,7 @@
 import jax.numpy as jnp
 from jax import random

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -101,6 +101,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
     """

+    @property
+    def has_state(self):
+        return True
+
     @register_to_config
     def __init__(
         self,
@@ -129,11 +133,12 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
         self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
         self.one = jnp.array(1.0)

-        self.state = DDPMSchedulerState.create(num_train_timesteps=num_train_timesteps)
-
-        self.variance_type = variance_type
+    def create_state(self):
+        return DDPMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)

-    def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDPMSchedulerState:
+    def set_timesteps(
+        self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple = ()
+    ) -> DDPMSchedulerState:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -214,7 +219,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
         """
         t = timestep

-        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
+        if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
             model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1)
         else:
             predicted_variance = None
@@ -267,13 +272,11 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_alpha_prod = sqrt_alpha_prod[..., None]
+        sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
+        sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
src/diffusers/schedulers/scheduling_karras_ve_flax.py
@@ -87,6 +87,10 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
             A reasonable range is [0.2, 80].
     """

+    @property
+    def has_state(self):
+        return True
+
     @register_to_config
     def __init__(
         self,
@@ -97,10 +101,13 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
         s_min: float = 0.05,
         s_max: float = 50,
     ):
-        self.state = KarrasVeSchedulerState.create()
+        pass
+
+    def create_state(self):
+        return KarrasVeSchedulerState.create()

     def set_timesteps(
-        self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple
+        self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple = ()
     ) -> KarrasVeSchedulerState:
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
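The `has_state` property plus `create_state` method is the new contract for stateful Flax schedulers: configuration stays on the scheduler object, while mutable state lives in a separate dataclass that callers thread through explicitly. The intended call pattern, sketched from how the new tests below drive the schedulers:

from diffusers import FlaxDDPMScheduler

# Config is fixed at construction; no mutable state is created in __init__ anymore.
scheduler = FlaxDDPMScheduler(num_train_timesteps=1000)

# Schedulers advertising `has_state == True` hand out their state explicitly.
state = scheduler.create_state()

# State is returned, not mutated in place.
state = scheduler.set_timesteps(state, num_inference_steps=10)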
src/diffusers/schedulers/scheduling_lms_discrete_flax.py
@@ -20,7 +20,7 @@
 import jax.numpy as jnp
 from scipy import integrate

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


 @flax.struct.dataclass
@@ -63,6 +63,10 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
             option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
     """

+    @property
+    def has_state(self):
+        return True
+
     @register_to_config
     def __init__(
         self,
@@ -85,8 +89,10 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)

+    def create_state(self):
         self.state = LMSDiscreteSchedulerState.create(
-            num_train_timesteps=num_train_timesteps, sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
+            num_train_timesteps=self.config.num_train_timesteps,
+            sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5,
         )

     def get_lms_coefficient(self, state, order, t, current_order):
@@ -112,7 +118,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
         return integrated_coeff

     def set_timesteps(
-        self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple
+        self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
     ) -> LMSDiscreteSchedulerState:
         """
         Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -199,8 +205,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
         sigma = state.sigmas[timesteps].flatten()
-        while len(sigma.shape) < len(noise.shape):
-            sigma = sigma[..., None]
+        sigma = broadcast_to_shape_from_left(sigma, noise.shape)

         noisy_samples = original_samples + noise * sigma
src/diffusers/schedulers/scheduling_pndm_flax.py
@@ -23,7 +23,7 @@
 import jax
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


 def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
@@ -168,6 +168,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
                 the `FlaxPNDMScheduler` state data class instance.
             num_inference_steps (`int`):
                 the number of diffusion steps used when generating samples with a pre-trained model.
+            shape (`Tuple`):
+                the shape of the samples to be generated.
         """
         offset = self.config.steps_offset
@@ -509,13 +511,11 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_alpha_prod = sqrt_alpha_prod[..., None]
+        sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
+        sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
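Per the "shape required for pndm" commit bullet, PNDM is the one scheduler that genuinely needs the sample shape at set_timesteps time (its state carries running buffers such as `ets`), which is why its docstring gains a `shape` entry while the other schedulers merely default it to `()`. Usage as exercised by the new tests (a sketch):

import jax.numpy as jnp
from diffusers import FlaxPNDMScheduler

scheduler = FlaxPNDMScheduler(num_train_timesteps=1000)
state = scheduler.create_state()

sample = jnp.zeros((4, 3, 8, 8))
# PNDM must know the sample shape up front; the other schedulers accept shape=().
state = scheduler.set_timesteps(state, num_inference_steps=50, shape=sample.shape)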
src/diffusers/schedulers/scheduling_sde_ve_flax.py
@@ -22,7 +22,7 @@
 import jax.numpy as jnp
 from jax import random

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


 @flax.struct.dataclass
@@ -80,6 +80,10 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
         correct_steps (`int`): number of correction steps performed on a produced sample.
     """

+    @property
+    def has_state(self):
+        return True
+
     @register_to_config
     def __init__(
         self,
@@ -90,12 +94,20 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
         sampling_eps: float = 1e-5,
         correct_steps: int = 1,
     ):
-        state = ScoreSdeVeSchedulerState.create()
-        self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps)
+        pass
+
+    def create_state(self):
+        state = ScoreSdeVeSchedulerState.create()
+        return self.set_sigmas(
+            state,
+            self.config.num_train_timesteps,
+            self.config.sigma_min,
+            self.config.sigma_max,
+            self.config.sampling_eps,
+        )

     def set_timesteps(
-        self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple, sampling_eps: float = None
+        self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None
     ) -> ScoreSdeVeSchedulerState:
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -193,8 +205,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
         # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
         # also equation 47 shows the analog from SDE models to ancestral sampling methods
         diffusion = diffusion.flatten()
-        while len(diffusion.shape) < len(sample.shape):
-            diffusion = diffusion[:, None]
+        diffusion = broadcast_to_shape_from_left(diffusion, sample.shape)
         drift = drift - diffusion**2 * model_output

         # equation 6: sample noise for the diffusion term of
@@ -252,8 +263,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
         # compute corrected sample: model_output term and noise term
         step_size = step_size.flatten()
-        while len(step_size.shape) < len(sample.shape):
-            step_size = step_size[:, None]
+        step_size = broadcast_to_shape_from_left(step_size, sample.shape)
         prev_sample_mean = sample + step_size * model_output
         prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise
src/diffusers/schedulers/scheduling_utils_flax.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
+from typing import Tuple

 import jax.numpy as jnp
@@ -41,3 +42,8 @@ class FlaxSchedulerMixin:
     """

     config_name = SCHEDULER_CONFIG_NAME
+
+
+def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray:
+    assert len(shape) >= x.ndim
+    return jnp.broadcast_to(x.reshape(x.shape + (1,) * (len(shape) - x.ndim)), shape)
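`broadcast_to_shape_from_left` aligns `x` with the leftmost dimensions of `shape`, appending singleton axes on the right before broadcasting. A quick illustration (a sketch; the values are arbitrary):

import jax.numpy as jnp
from diffusers.schedulers.scheduling_utils_flax import broadcast_to_shape_from_left

x = jnp.array([1.0, 2.0])                      # shape (2,)
out = broadcast_to_shape_from_left(x, (2, 3))  # reshaped to (2, 1), broadcast to (2, 3)

# out == [[1., 1., 1.],
#         [2., 2., 2.]]
assert out.shape == (2, 3)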
tests/test_scheduler_flax.py (new file)
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from typing import Dict, List, Tuple

from diffusers import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax


if is_flax_available():
    import jax.numpy as jnp
    from jax import random
@require_flax
class FlaxSchedulerCommonTest(unittest.TestCase):
    scheduler_classes = ()
    forward_default_kwargs = ()

    @property
    def dummy_sample(self):
        batch_size = 4
        num_channels = 3
        height = 8
        width = 8

        key1, key2 = random.split(random.PRNGKey(0))
        sample = random.uniform(key1, (batch_size, num_channels, height, width))

        return sample, key2

    @property
    def dummy_sample_deter(self):
        batch_size = 4
        num_channels = 3
        height = 8
        width = 8

        num_elems = batch_size * num_channels * height * width
        sample = jnp.arange(num_elems)
        sample = sample.reshape(num_channels, height, width, batch_size)
        sample = sample / num_elems
        return jnp.transpose(sample, (3, 0, 1, 2))

    def get_scheduler_config(self):
        raise NotImplementedError

    def dummy_model(self):
        def model(sample, t, *args):
            return sample * t / (t + 1)

        return model

    def check_over_configs(self, time_step=0, **config):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, key = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_config(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def check_over_forward(self, time_step=0, **forward_kwargs):
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, key = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_config(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_pretrained_save_pretrained(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, key = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_config(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, 1, sample, key, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_step_shape(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, key = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output_0 = scheduler.step(state, residual, 0, sample, key, **kwargs).prev_sample
            output_1 = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_scheduler_outputs_equivalence(self):
        def set_nan_tensor_to_zero(t):
            return t.at[t != t].set(0)

        def recursive_check(tuple_object, dict_object):
            if isinstance(tuple_object, (List, Tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, Dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                self.assertTrue(
                    jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
                    msg=(
                        "Tuple and dict output are not equal. Difference:"
                        f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                        f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has"
                        f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}."
                    ),
                )

        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, key = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_dict = scheduler.step(state, residual, 0, sample, key, **kwargs)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_tuple = scheduler.step(state, residual, 0, sample, key, return_dict=False, **kwargs)

            recursive_check(outputs_tuple[0], outputs_dict.prev_sample)
@require_flax
class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest):
    scheduler_classes = (FlaxDDPMScheduler,)

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "variance_type": "fixed_small",
            "clip_sample": True,
        }

        config.update(**kwargs)
        return config

    def test_timesteps(self):
        for timesteps in [1, 5, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_variance_type(self):
        for variance in ["fixed_small", "fixed_large", "other"]:
            self.check_over_configs(variance_type=variance)

    def test_clip_sample(self):
        for clip_sample in [True, False]:
            self.check_over_configs(clip_sample=clip_sample)

    def test_time_indices(self):
        for t in [0, 500, 999]:
            self.check_over_forward(time_step=t)

    def test_variance(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        assert jnp.sum(jnp.abs(scheduler._get_variance(0) - 0.0)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(999) - 0.02)) < 1e-5

    def test_full_loop_no_noise(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()

        num_trained_timesteps = len(scheduler)

        model = self.dummy_model()
        sample = self.dummy_sample_deter
        key1, key2 = random.split(random.PRNGKey(0))

        for t in reversed(range(num_trained_timesteps)):
            # 1. predict noise residual
            residual = model(sample, t)

            # 2. predict previous mean of sample x_t-1
            output = scheduler.step(state, residual, t, sample, key1)
            pred_prev_sample = output.prev_sample
            state = output.state
            key1, key2 = random.split(key2)

            # if t > 0:
            #     noise = self.dummy_sample_deter
            #     variance = scheduler.get_variance(t) ** (0.5) * noise
            #
            # sample = pred_prev_sample + variance
            sample = pred_prev_sample

        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        assert abs(result_sum - 255.1113) < 1e-2
        assert abs(result_mean - 0.332176) < 1e-3
@require_flax
class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
    scheduler_classes = (FlaxDDIMScheduler,)
    forward_default_kwargs = (("num_inference_steps", 50),)

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
        }

        config.update(**kwargs)
        return config

    def full_loop(self, **config):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()
        key1, key2 = random.split(random.PRNGKey(0))

        num_inference_steps = 10

        model = self.dummy_model()
        sample = self.dummy_sample_deter

        state = scheduler.set_timesteps(state, num_inference_steps)

        for t in state.timesteps:
            residual = model(sample, t)
            output = scheduler.step(state, residual, t, sample)
            sample = output.prev_sample
            state = output.state
            key1, key2 = random.split(key2)

        return sample

    def check_over_configs(self, time_step=0, **config):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_config(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_pretrained_save_pretrained(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_config(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def check_over_forward(self, time_step=0, **forward_kwargs):
        kwargs = dict(self.forward_default_kwargs)
        kwargs.update(forward_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_config(tmpdirname)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_scheduler_outputs_equivalence(self):
        def set_nan_tensor_to_zero(t):
            return t.at[t != t].set(0)

        def recursive_check(tuple_object, dict_object):
            if isinstance(tuple_object, (List, Tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, Dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                self.assertTrue(
                    jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
                    msg=(
                        "Tuple and dict output are not equal. Difference:"
                        f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                        f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has"
                        f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}."
                    ),
                )

        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs)

            recursive_check(outputs_tuple[0], outputs_dict.prev_sample)

    def test_step_shape(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            output_0 = scheduler.step(state, residual, 0, sample, **kwargs).prev_sample
            output_1 = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_timesteps(self):
        for timesteps in [100, 500, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_steps_offset(self):
        for steps_offset in [0, 1]:
            self.check_over_configs(steps_offset=steps_offset)

        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(steps_offset=1)
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()
        state = scheduler.set_timesteps(state, 5)
        assert jnp.equal(state.timesteps, jnp.array([801, 601, 401, 201, 1])).all()

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_time_indices(self):
        for t in [1, 10, 49]:
            self.check_over_forward(time_step=t)

    def test_inference_steps(self):
        for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
            self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)

    def test_variance(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()

        assert jnp.sum(jnp.abs(scheduler._get_variance(0, 0, state.alphas_cumprod) - 0.0)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(420, 400, state.alphas_cumprod) - 0.14771)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(980, 960, state.alphas_cumprod) - 0.32460)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(0, 0, state.alphas_cumprod) - 0.0)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(487, 486, state.alphas_cumprod) - 0.00979)) < 1e-5
        assert jnp.sum(jnp.abs(scheduler._get_variance(999, 998, state.alphas_cumprod) - 0.02)) < 1e-5

    def test_full_loop_no_noise(self):
        sample = self.full_loop()

        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        assert abs(result_sum - 172.0067) < 1e-2
        assert abs(result_mean - 0.223967) < 1e-3

    def test_full_loop_with_set_alpha_to_one(self):
        # We specify different beta, so that the first alpha is 0.99
        sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        assert abs(result_sum - 149.8295) < 1e-2
        assert abs(result_mean - 0.1951) < 1e-3

    def test_full_loop_with_no_set_alpha_to_one(self):
        # We specify different beta, so that the first alpha is 0.99
        sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        assert abs(result_sum - 149.0784) < 1e-2
        assert abs(result_mean - 0.1941) < 1e-3
@require_flax
class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
    scheduler_classes = (FlaxPNDMScheduler,)
    forward_default_kwargs = (("num_inference_steps", 50),)

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
        }

        config.update(**kwargs)
        return config

    def check_over_configs(self, time_step=0, **config):
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample, _ = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()
            state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
            # copy over dummy past residuals
            state = state.replace(ets=dummy_past_residuals[:])

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
                # copy over dummy past residuals
                new_state = new_state.replace(ets=dummy_past_residuals[:])

            (prev_sample, state) = scheduler.step_prk(state, residual, time_step, sample, **kwargs)
            (new_prev_sample, new_state) = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs)

            assert jnp.sum(jnp.abs(prev_sample - new_prev_sample)) < 1e-5, "Scheduler outputs are not identical"

            output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs)
            new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs)

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_pretrained_save_pretrained(self):
        pass

    def test_scheduler_outputs_equivalence(self):
        def set_nan_tensor_to_zero(t):
            return t.at[t != t].set(0)

        def recursive_check(tuple_object, dict_object):
            if isinstance(tuple_object, (List, Tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, Dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                self.assertTrue(
                    jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
                    msg=(
                        "Tuple and dict output are not equal. Difference:"
                        f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                        f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has"
                        f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}."
                    ),
                )

        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs)

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs)

            recursive_check(outputs_tuple[0], outputs_dict.prev_sample)
    def check_over_forward(self, time_step=0, **forward_kwargs):
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample, _ = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()
            state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
            # copy over dummy past residuals (must be after setting timesteps)
            state = state.replace(ets=dummy_past_residuals[:])

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
                new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
                # copy over dummy past residuals (must be after setting timesteps)
                new_state = new_state.replace(ets=dummy_past_residuals[:])

            output, state = scheduler.step_prk(state, residual, time_step, sample, **kwargs)
            new_output, new_state = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs)

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

            output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs)
            new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs)

            assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
    def full_loop(self, **config):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)

        for i, t in enumerate(state.prk_timesteps):
            residual = model(sample, t)
            sample, state = scheduler.step_prk(state, residual, t, sample)

        for i, t in enumerate(state.plms_timesteps):
            residual = model(sample, t)
            sample, state = scheduler.step_plms(state, residual, t, sample)

        return sample

    def test_step_shape(self):
        kwargs = dict(self.forward_default_kwargs)

        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            # copy over dummy past residuals (must be done after set_timesteps)
            dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])
            state = state.replace(ets=dummy_past_residuals[:])

            output_0, state = scheduler.step_prk(state, residual, 0, sample, **kwargs)
            output_1, state = scheduler.step_prk(state, residual, 1, sample, **kwargs)

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

            output_0, state = scheduler.step_plms(state, residual, 0, sample, **kwargs)
            output_1, state = scheduler.step_plms(state, residual, 1, sample, **kwargs)

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_timesteps(self):
        for timesteps in [100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_steps_offset(self):
        for steps_offset in [0, 1]:
            self.check_over_configs(steps_offset=steps_offset)

        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(steps_offset=1)
        scheduler = scheduler_class(**scheduler_config)
        state = scheduler.create_state()
        state = scheduler.set_timesteps(state, 10, shape=())
        assert jnp.equal(
            state.timesteps,
            jnp.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]),
        ).all()

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_time_indices(self):
        for t in [1, 5, 10]:
            self.check_over_forward(time_step=t)

    def test_inference_steps(self):
        for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
            self.check_over_forward(num_inference_steps=num_inference_steps)

    def test_pow_of_3_inference_steps(self):
        # earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3
        num_inference_steps = 27

        for scheduler_class in self.scheduler_classes:
            sample, _ = self.dummy_sample
            residual = 0.1 * sample

            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)

            # before power of 3 fix, would error on first step, so we only need to do two
            for i, t in enumerate(state.prk_timesteps[:2]):
                sample, state = scheduler.step_prk(state, residual, t, sample)

    def test_inference_plms_no_past_residuals(self):
        with self.assertRaises(ValueError):
            scheduler_class = self.scheduler_classes[0]
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            state = scheduler.create_state()

            scheduler.step_plms(state, self.dummy_sample, 1, self.dummy_sample).prev_sample

    def test_full_loop_no_noise(self):
        sample = self.full_loop()
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        assert abs(result_sum - 198.1318) < 1e-2
        assert abs(result_mean - 0.2580) < 1e-3

    def test_full_loop_with_set_alpha_to_one(self):
        # We specify different beta, so that the first alpha is 0.99
        sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        assert abs(result_sum - 186.9466) < 1e-2
        assert abs(result_mean - 0.24342) < 1e-3

    def test_full_loop_with_no_set_alpha_to_one(self):
        # We specify different beta, so that the first alpha is 0.99
        sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
        result_sum = jnp.sum(jnp.abs(sample))
        result_mean = jnp.mean(jnp.abs(sample))

        assert abs(result_sum - 186.9482) < 1e-2
        assert abs(result_mean - 0.2434) < 1e-3