Unverified Commit 240abddf authored by Kashif Rasul, committed by GitHub

[Flax] added broadcast_to_shape_from_left helper and Scheduler tests (#864)



* added broadcast_to_shape_from_left helper

* initial tests

* fixed pndm tests

* shape required for pndm

* added require_flax

* fix style

* fix more imports
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 38ae5a25
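
The "initial tests" and "added require_flax" bullets land in a tests file whose diff is collapsed at the bottom of this page. A hedged sketch of the gating pattern; the decorator name comes from the commit message, and its import path is an assumption, not shown in this diff:

```python
# Hypothetical smoke test; `require_flax` skips the class when Flax is absent.
import unittest

from diffusers.utils.testing_utils import require_flax


@require_flax
class FlaxSchedulerSmokeTest(unittest.TestCase):
    def test_broadcast_helper_is_exported(self):
        # the re-export added in the first hunk below
        from diffusers.schedulers.scheduling_utils_flax import broadcast_to_shape_from_left

        self.assertTrue(callable(broadcast_to_shape_from_left))
```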
src/diffusers/schedulers/__init__.py
@@ -34,7 +34,7 @@ if is_flax_available():
     from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
     from .scheduling_pndm_flax import FlaxPNDMScheduler
     from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
-    from .scheduling_utils_flax import FlaxSchedulerMixin
+    from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
 else:
     from ..utils.dummy_flax_objects import *  # noqa F403
src/diffusers/schedulers/scheduling_ddim_flax.py
@@ -23,7 +23,7 @@ import flax
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -173,7 +173,9 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
         return variance

-    def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDIMSchedulerState:
+    def set_timesteps(
+        self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = ()
+    ) -> DDIMSchedulerState:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
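
With `shape` now defaulting to `()`, schedulers that keep no per-sample buffers can be driven without knowing the sample shape up front. A minimal usage sketch, assuming `create_state` is already available on `FlaxDDIMScheduler` at this version:

```python
from diffusers import FlaxDDIMScheduler

scheduler = FlaxDDIMScheduler(num_train_timesteps=1000)
state = scheduler.create_state()

# `shape` can now be omitted; it defaults to ()
state = scheduler.set_timesteps(state, num_inference_steps=50)
print(state.timesteps.shape)  # expected: (50,)
```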
@@ -211,9 +213,6 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`jnp.ndarray`):
                 current instance of sample being created by diffusion process.
-            key (`random.KeyArray`): a PRNG key.
-            eta (`float`): weight of noise for added noise in diffusion step.
-            use_clipped_model_output (`bool`): TODO
             return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class

         Returns:
@@ -279,13 +278,11 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_alpha_prod = sqrt_alpha_prod[:, None]
+        sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[:, None]
+        sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
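
The removed loops and the new helper produce the same left-aligned broadcast. A self-contained check, with hypothetical shapes:

```python
import jax.numpy as jnp

from diffusers.schedulers.scheduling_utils_flax import broadcast_to_shape_from_left

coeff = jnp.array([0.1, 0.2])   # one scalar per batch element
sample_shape = (2, 3, 8, 8)     # hypothetical (batch, channels, height, width)

# old pattern: grow trailing axes one at a time
old = coeff
while len(old.shape) < len(sample_shape):
    old = old[..., None]

# new pattern: reshape once, then broadcast
new = broadcast_to_shape_from_left(coeff, sample_shape)

assert new.shape == sample_shape
assert jnp.allclose(old * jnp.ones(sample_shape), new)
```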
src/diffusers/schedulers/scheduling_ddpm_flax.py
@@ -23,7 +23,7 @@ import jax.numpy as jnp
 from jax import random

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -101,6 +101,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
     """

+    @property
+    def has_state(self):
+        return True
+
     @register_to_config
     def __init__(
         self,
@@ -129,11 +133,12 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
         self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
         self.one = jnp.array(1.0)

-        self.state = DDPMSchedulerState.create(num_train_timesteps=num_train_timesteps)
+    def create_state(self):
+        return DDPMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)

-        self.variance_type = variance_type
-
-    def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDPMSchedulerState:
+    def set_timesteps(
+        self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple = ()
+    ) -> DDPMSchedulerState:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
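
Together with the `has_state` property above, this establishes the stateful-scheduler protocol: construction is side-effect free, and callers request the immutable state explicitly instead of reading `scheduler.state`. A hedged usage sketch:

```python
from diffusers import FlaxDDPMScheduler

scheduler = FlaxDDPMScheduler(num_train_timesteps=1000)

# construction no longer builds state as a side effect
assert scheduler.has_state
state = scheduler.create_state()
state = scheduler.set_timesteps(state, num_inference_steps=50)
```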
@@ -214,7 +219,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
         """
         t = timestep

-        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
+        if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
             model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1)
         else:
             predicted_variance = None
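
`@register_to_config` records constructor arguments on `self.config`, which is why the separate `self.variance_type` instance attribute could be dropped. A small sketch of that assumption:

```python
from diffusers import FlaxDDPMScheduler

scheduler = FlaxDDPMScheduler(variance_type="fixed_small")
# the constructor argument is reachable via the frozen config,
# so a separate instance attribute is redundant
assert scheduler.config.variance_type == "fixed_small"
```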
@@ -267,13 +272,11 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_alpha_prod = sqrt_alpha_prod[..., None]
+        sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
+        sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
src/diffusers/schedulers/scheduling_karras_ve_flax.py
@@ -87,6 +87,10 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
             A reasonable range is [0.2, 80].
     """

+    @property
+    def has_state(self):
+        return True
+
     @register_to_config
     def __init__(
         self,
@@ -97,10 +101,13 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
         s_min: float = 0.05,
         s_max: float = 50,
     ):
-        self.state = KarrasVeSchedulerState.create()
+        pass
+
+    def create_state(self):
+        return KarrasVeSchedulerState.create()

     def set_timesteps(
-        self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple
+        self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple = ()
     ) -> KarrasVeSchedulerState:
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
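
The `pass` body is deliberate: `@register_to_config` already captures the constructor arguments, and the mutable `self.state` assignment has moved to `create_state`, so `__init__` has nothing left to do. A sketch under that assumption:

```python
from diffusers import FlaxKarrasVeScheduler

scheduler = FlaxKarrasVeScheduler(s_min=0.05, s_max=50)
assert scheduler.config.s_max == 50  # args survive via the config despite `pass`
state = scheduler.create_state()     # a fresh, immutable state on request
```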
src/diffusers/schedulers/scheduling_lms_discrete_flax.py
@@ -20,7 +20,7 @@ import jax.numpy as jnp
 from scipy import integrate

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


 @flax.struct.dataclass
@@ -63,6 +63,10 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
             option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
     """

+    @property
+    def has_state(self):
+        return True
+
     @register_to_config
     def __init__(
         self,
@@ -85,8 +89,10 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)

+    def create_state(self):
         self.state = LMSDiscreteSchedulerState.create(
-            num_train_timesteps=num_train_timesteps, sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
+            num_train_timesteps=self.config.num_train_timesteps,
+            sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5,
         )

     def get_lms_coefficient(self, state, order, t, current_order):
@@ -112,7 +118,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
         return integrated_coeff

     def set_timesteps(
-        self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple
+        self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
     ) -> LMSDiscreteSchedulerState:
         """
         Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -199,8 +205,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
         sigma = state.sigmas[timesteps].flatten()
-        while len(sigma.shape) < len(noise.shape):
-            sigma = sigma[..., None]
+        sigma = broadcast_to_shape_from_left(sigma, noise.shape)

         noisy_samples = original_samples + noise * sigma
src/diffusers/schedulers/scheduling_pndm_flax.py
@@ -23,7 +23,7 @@ import jax
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


 def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
@@ -168,6 +168,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
                 the `FlaxPNDMScheduler` state data class instance.
             num_inference_steps (`int`):
                 the number of diffusion steps used when generating samples with a pre-trained model.
+            shape (`Tuple`):
+                the shape of the samples to be generated.
         """
         offset = self.config.steps_offset
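
Unlike the other schedulers, PNDM keeps a rolling buffer of past model outputs inside its state, so it genuinely needs the sample shape up front; this is the "shape required for pndm" bullet from the commit message. A hedged sketch, with a hypothetical latent shape:

```python
from diffusers import FlaxPNDMScheduler

scheduler = FlaxPNDMScheduler(num_train_timesteps=1000)
state = scheduler.create_state()
# PNDM pre-allocates per-sample buffers, so the sample shape matters here
state = scheduler.set_timesteps(state, num_inference_steps=50, shape=(1, 3, 64, 64))
```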
@@ -509,13 +511,11 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_alpha_prod = sqrt_alpha_prod[..., None]
+        sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
+        sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
src/diffusers/schedulers/scheduling_sde_ve_flax.py
@@ -22,7 +22,7 @@ import jax.numpy as jnp
 from jax import random

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


 @flax.struct.dataclass
@@ -80,6 +80,10 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
         correct_steps (`int`): number of correction steps performed on a produced sample.
     """

+    @property
+    def has_state(self):
+        return True
+
     @register_to_config
     def __init__(
         self,
@@ -90,12 +94,20 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
         sampling_eps: float = 1e-5,
         correct_steps: int = 1,
     ):
-        state = ScoreSdeVeSchedulerState.create()
-        self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps)
+        pass
+
+    def create_state(self):
+        state = ScoreSdeVeSchedulerState.create()
+        return self.set_sigmas(
+            state,
+            self.config.num_train_timesteps,
+            self.config.sigma_min,
+            self.config.sigma_max,
+            self.config.sampling_eps,
+        )

     def set_timesteps(
-        self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple, sampling_eps: float = None
+        self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None
     ) -> ScoreSdeVeSchedulerState:
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
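
Here `create_state` reuses the public `set_sigmas` method so the initial sigma schedule is derived purely from the stored config rather than from constructor side effects. A usage sketch, argument values hypothetical:

```python
from diffusers import FlaxScoreSdeVeScheduler

scheduler = FlaxScoreSdeVeScheduler(num_train_timesteps=2000)
state = scheduler.create_state()  # sigma schedule already populated from the config
state = scheduler.set_timesteps(state, num_inference_steps=100)
```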
@@ -193,8 +205,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
         # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
         # also equation 47 shows the analog from SDE models to ancestral sampling methods
         diffusion = diffusion.flatten()
-        while len(diffusion.shape) < len(sample.shape):
-            diffusion = diffusion[:, None]
+        diffusion = broadcast_to_shape_from_left(diffusion, sample.shape)
         drift = drift - diffusion**2 * model_output

         # equation 6: sample noise for the diffusion term of
@@ -252,8 +263,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
         # compute corrected sample: model_output term and noise term
         step_size = step_size.flatten()
-        while len(step_size.shape) < len(sample.shape):
-            step_size = step_size[:, None]
+        step_size = broadcast_to_shape_from_left(step_size, sample.shape)
         prev_sample_mean = sample + step_size * model_output
         prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise
src/diffusers/schedulers/scheduling_utils_flax.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
+from typing import Tuple

 import jax.numpy as jnp
@@ -41,3 +42,8 @@ class FlaxSchedulerMixin:
     """

     config_name = SCHEDULER_CONFIG_NAME
+
+
+def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray:
+    assert len(shape) >= x.ndim
+    return jnp.broadcast_to(x.reshape(x.shape + (1,) * (len(shape) - x.ndim)), shape)
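
A quick check of the helper's semantics (align the existing axes on the left, pad with size-1 axes on the right, then broadcast); self-contained, assuming nothing beyond the function above:

```python
import jax.numpy as jnp


def broadcast_to_shape_from_left(x, shape):
    assert len(shape) >= x.ndim
    return jnp.broadcast_to(x.reshape(x.shape + (1,) * (len(shape) - x.ndim)), shape)


x = jnp.array([1.0, 2.0])                       # shape (2,)
y = broadcast_to_shape_from_left(x, (2, 3, 4))  # padded on the right to (2, 1, 1)
assert y.shape == (2, 3, 4)
assert bool((y[0] == 1.0).all()) and bool((y[1] == 2.0).all())
```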
This diff is collapsed.