Unverified commit 240abddf authored by Kashif Rasul, committed by GitHub

[Flax] added broadcast_to_shape_from_left helper and Scheduler tests (#864)



* added broadcast_to_shape_from_left helper

* initial tests

* fixed pndm tests

* shape required for pndm

* added require_flax

* fix style

* fix more imports

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 38ae5a25
@@ -34,7 +34,7 @@ if is_flax_available():
     from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
     from .scheduling_pndm_flax import FlaxPNDMScheduler
     from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
-    from .scheduling_utils_flax import FlaxSchedulerMixin
+    from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
 else:
     from ..utils.dummy_flax_objects import *  # noqa F403
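With this re-export in place, both the new helper and `FlaxSchedulerOutput` can be pulled straight from the subpackage. A minimal sketch, assuming a Flax-enabled install of diffusers at this commit:

```python
# hypothetical consumer code; requires jax and flax to be installed
from diffusers.schedulers import FlaxSchedulerOutput, broadcast_to_shape_from_left
```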
@@ -23,7 +23,7 @@ import flax
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -173,7 +173,9 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):

         return variance

-    def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDIMSchedulerState:
+    def set_timesteps(
+        self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = ()
+    ) -> DDIMSchedulerState:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -211,9 +213,6 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`jnp.ndarray`):
                 current instance of sample being created by diffusion process.
-            key (`random.KeyArray`): a PRNG key.
-            eta (`float`): weight of noise for added noise in diffusion step.
-            use_clipped_model_output (`bool`): TODO
             return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class

         Returns:
@@ -279,13 +278,11 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_alpha_prod = sqrt_alpha_prod[:, None]
+        sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[:, None]
+        sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
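The removed while-loops appended trailing singleton axes and relied on implicit broadcasting in the final multiply; the helper makes the broadcast explicit. A quick standalone check of the equivalence (the shapes here are made up for illustration):

```python
import jax.numpy as jnp

from diffusers.schedulers import broadcast_to_shape_from_left

x = jnp.arange(4.0)               # one scalar per batch element, shape (4,)
samples = jnp.ones((4, 3, 8, 8))  # hypothetical sample batch

# old pattern: pad trailing axes, let the multiply broadcast implicitly
y_old = x
while len(y_old.shape) < len(samples.shape):
    y_old = y_old[..., None]      # ends up with shape (4, 1, 1, 1)

# new pattern: one explicit broadcast to the full sample shape
y_new = broadcast_to_shape_from_left(x, samples.shape)  # shape (4, 3, 8, 8)

assert jnp.array_equal(y_old * samples, y_new * samples)
```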
@@ -23,7 +23,7 @@ import jax.numpy as jnp
 from jax import random

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -101,6 +101,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
     """

+    @property
+    def has_state(self):
+        return True
+
     @register_to_config
     def __init__(
         self,
@@ -129,11 +133,12 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
         self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
         self.one = jnp.array(1.0)

-        self.state = DDPMSchedulerState.create(num_train_timesteps=num_train_timesteps)
-        self.variance_type = variance_type
+    def create_state(self):
+        return DDPMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)

-    def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDPMSchedulerState:
+    def set_timesteps(
+        self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple = ()
+    ) -> DDPMSchedulerState:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -214,7 +219,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
         """
         t = timestep

-        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
+        if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
             model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1)
         else:
             predicted_variance = None
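Switching to `self.config.variance_type` works because `@register_to_config` records every `__init__` argument on `self.config`, making the dropped instance attribute redundant. For instance:

```python
from diffusers import FlaxDDPMScheduler

# variance_type is captured by @register_to_config at construction time
scheduler = FlaxDDPMScheduler(variance_type="fixed_small")
assert scheduler.config.variance_type == "fixed_small"
```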
@@ -267,13 +272,11 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_alpha_prod = sqrt_alpha_prod[..., None]
+        sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
+        sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
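Taken together, the DDPM changes make the scheduler object itself stateless: state now comes from `create_state` and is threaded through `set_timesteps`. A minimal usage sketch (the step count and sample shape are illustrative):

```python
from diffusers import FlaxDDPMScheduler

scheduler = FlaxDDPMScheduler(num_train_timesteps=1000)

# has_state flags that this scheduler needs an explicit state object
assert scheduler.has_state
state = scheduler.create_state()

# state is returned, not mutated in place, so rebind it
state = scheduler.set_timesteps(state, num_inference_steps=50, shape=(1, 3, 64, 64))
```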
@@ -87,6 +87,10 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
             A reasonable range is [0.2, 80].
     """

+    @property
+    def has_state(self):
+        return True
+
     @register_to_config
     def __init__(
         self,
@@ -97,10 +101,13 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
         s_min: float = 0.05,
         s_max: float = 50,
     ):
-        self.state = KarrasVeSchedulerState.create()
+        pass
+
+    def create_state(self):
+        return KarrasVeSchedulerState.create()

     def set_timesteps(
-        self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple
+        self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple = ()
     ) -> KarrasVeSchedulerState:
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -20,7 +20,7 @@ import jax.numpy as jnp
 from scipy import integrate

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


 @flax.struct.dataclass
@@ -63,6 +63,10 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
             option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
     """

+    @property
+    def has_state(self):
+        return True
+
     @register_to_config
     def __init__(
         self,
@@ -85,8 +89,10 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)

-        self.state = LMSDiscreteSchedulerState.create(
-            num_train_timesteps=num_train_timesteps, sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
+    def create_state(self):
+        return LMSDiscreteSchedulerState.create(
+            num_train_timesteps=self.config.num_train_timesteps,
+            sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5,
         )

     def get_lms_coefficient(self, state, order, t, current_order):
@@ -112,7 +118,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
         return integrated_coeff

     def set_timesteps(
-        self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple
+        self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
     ) -> LMSDiscreteSchedulerState:
         """
         Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -199,8 +205,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
         sigma = state.sigmas[timesteps].flatten()
-        while len(sigma.shape) < len(noise.shape):
-            sigma = sigma[..., None]
+        sigma = broadcast_to_shape_from_left(sigma, noise.shape)

         noisy_samples = original_samples + noise * sigma
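Since state creation moved out of `__init__`, the constructor arguments are no longer in scope, so `create_state` reads `num_train_timesteps` back from `self.config` (populated by `@register_to_config`). A small sketch of the resulting state:

```python
from diffusers import FlaxLMSDiscreteScheduler

scheduler = FlaxLMSDiscreteScheduler(num_train_timesteps=1000)
state = scheduler.create_state()

# sigmas follow ((1 - alpha_bar) / alpha_bar) ** 0.5, one per training timestep
assert state.sigmas.shape == (1000,)
```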
@@ -23,7 +23,7 @@ import jax
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


 def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
@@ -168,6 +168,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
                 the `FlaxPNDMScheduler` state data class instance.
             num_inference_steps (`int`):
                 the number of diffusion steps used when generating samples with a pre-trained model.
+            shape (`Tuple`):
+                the shape of the samples to be generated.
         """
         offset = self.config.steps_offset
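The new `shape` docstring matters most here: per the commit message's "shape required for pndm" bullet, the sample shape must be supplied at setup time, presumably because PNDM's linear multistep update keeps fixed-shape buffers of previous model outputs in its state. A hedged sketch (the shape values are illustrative, and the `create_state` call assumes PNDM follows the same state pattern as the other Flax schedulers):

```python
from diffusers import FlaxPNDMScheduler

scheduler = FlaxPNDMScheduler(num_train_timesteps=1000)
state = scheduler.create_state()  # assumed to exist, mirroring the other schedulers

# unlike the other schedulers, shape cannot be left at its () default for PNDM
state = scheduler.set_timesteps(state, num_inference_steps=50, shape=(1, 3, 64, 64))
```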
@@ -509,13 +511,11 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_alpha_prod = sqrt_alpha_prod[..., None]
+        sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
+        sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
@@ -22,7 +22,7 @@ import jax.numpy as jnp
 from jax import random

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


 @flax.struct.dataclass
@@ -80,6 +80,10 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
         correct_steps (`int`): number of correction steps performed on a produced sample.
     """

+    @property
+    def has_state(self):
+        return True
+
     @register_to_config
     def __init__(
         self,
@@ -90,12 +94,20 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
         sampling_eps: float = 1e-5,
         correct_steps: int = 1,
     ):
-        state = ScoreSdeVeSchedulerState.create()
+        pass

-        self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps)
+    def create_state(self):
+        state = ScoreSdeVeSchedulerState.create()
+        return self.set_sigmas(
+            state,
+            self.config.num_train_timesteps,
+            self.config.sigma_min,
+            self.config.sigma_max,
+            self.config.sampling_eps,
+        )

     def set_timesteps(
-        self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple, sampling_eps: float = None
+        self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None
     ) -> ScoreSdeVeSchedulerState:
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
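Note that `create_state` here chains through `set_sigmas`, so the returned state already carries its sigma schedule built from the registered config. A short sketch:

```python
from diffusers import FlaxScoreSdeVeScheduler

scheduler = FlaxScoreSdeVeScheduler(num_train_timesteps=2000)

# create_state -> set_sigmas, so no separate schedule setup is needed for defaults
state = scheduler.create_state()
```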
@@ -193,8 +205,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
         # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
         # also equation 47 shows the analog from SDE models to ancestral sampling methods
         diffusion = diffusion.flatten()
-        while len(diffusion.shape) < len(sample.shape):
-            diffusion = diffusion[:, None]
+        diffusion = broadcast_to_shape_from_left(diffusion, sample.shape)
         drift = drift - diffusion**2 * model_output

         # equation 6: sample noise for the diffusion term of
@@ -252,8 +263,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
         # compute corrected sample: model_output term and noise term
         step_size = step_size.flatten()
-        while len(step_size.shape) < len(sample.shape):
-            step_size = step_size[:, None]
+        step_size = broadcast_to_shape_from_left(step_size, sample.shape)
         prev_sample_mean = sample + step_size * model_output
         prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
+from typing import Tuple

 import jax.numpy as jnp
@@ -41,3 +42,8 @@ class FlaxSchedulerMixin:
     """

     config_name = SCHEDULER_CONFIG_NAME
+
+
+def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray:
+    assert len(shape) >= x.ndim
+    return jnp.broadcast_to(x.reshape(x.shape + (1,) * (len(shape) - x.ndim)), shape)
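The helper itself reshapes `x` with enough trailing singleton axes to match `len(shape)`, then materializes the broadcast, i.e. it aligns dimensions from the left (plain NumPy/JAX broadcasting aligns from the right). A tiny illustration:

```python
import jax.numpy as jnp

x = jnp.array([0.1, 0.2])                         # shape (2,)
out = broadcast_to_shape_from_left(x, (2, 3, 4))  # reshape to (2, 1, 1), then broadcast
assert out.shape == (2, 3, 4)
assert bool((out[0] == 0.1).all()) and bool((out[1] == 0.2).all())
```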