Unverified Commit b1182bcf authored by Kashif Rasul, committed by GitHub

[Flax] fix Flax scheduler (#564)

* remove match_shape

* ported fixes from #479 to flax

* remove unused argument

* typo

* remove warnings
parent 0424615a
@@ -96,7 +96,13 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
         clip_sample (`bool`, default `True`):
             option to clip predicted sample between -1 and 1 for numerical stability.
         set_alpha_to_one (`bool`, default `True`):
-            if alpha for final step is 1 or the final alpha of the "non-previous" one.
+            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+            otherwise it uses the value of alpha at step 0.
+        steps_offset (`int`, default `0`):
+            an offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+            stable diffusion.
     """

     @register_to_config
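The interplay of `set_alpha_to_one` and `steps_offset` described above is easiest to see with concrete numbers. A minimal sketch with toy values; only the config names come from this diff:

```python
import jax.numpy as jnp

num_train_timesteps = 1000
num_inference_steps = 50
steps_offset = 1  # stable-diffusion style

step_ratio = num_train_timesteps // num_inference_steps  # 20
# descending inference timesteps: 981, 961, ..., 21, 1
timesteps = (jnp.arange(0, num_inference_steps) * step_ratio)[::-1] + steps_offset

# toy alpha products, for illustration only
alphas_cumprod = jnp.linspace(0.9999, 0.01, num_train_timesteps)

# at the final step (t=1) the "previous" step t - 20 falls before step 0;
# set_alpha_to_one=True substitutes 1.0, otherwise alpha at step 0 is used
set_alpha_to_one = False
final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else alphas_cumprod[0]
```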
@@ -109,6 +115,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
         trained_betas: Optional[jnp.ndarray] = None,
         clip_sample: bool = True,
         set_alpha_to_one: bool = True,
+        steps_offset: int = 0,
     ):
         if trained_betas is not None:
             self.betas = jnp.asarray(trained_betas)
@@ -144,9 +151,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
         return variance

-    def set_timesteps(
-        self, state: DDIMSchedulerState, num_inference_steps: int, offset: int = 0
-    ) -> DDIMSchedulerState:
+    def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int) -> DDIMSchedulerState:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -155,9 +160,9 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
                 the `FlaxDDIMScheduler` state data class instance.
             num_inference_steps (`int`):
                 the number of diffusion steps used when generating samples with a pre-trained model.
-            offset (`int`):
-                optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
         """
+        offset = self.config.steps_offset
+
         step_ratio = self.config.num_train_timesteps // num_inference_steps
         # creates integer timesteps by multiplying by ratio
         # casting to int to avoid issues when num_inference_step is power of 3
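On the calling side, the offset now travels with the scheduler config instead of being threaded through every `set_timesteps` call. A hedged usage sketch; the constructor arguments appear in this diff, but the surrounding calls are illustrative rather than verbatim pipeline code:

```python
from diffusers import FlaxDDIMScheduler

# before: state = scheduler.set_timesteps(state, num_inference_steps=50, offset=1)
scheduler = FlaxDDIMScheduler(
    num_train_timesteps=1000,
    steps_offset=1,          # replaces the per-call `offset` argument
    set_alpha_to_one=False,  # stable-diffusion style final alpha product
)
state = scheduler.set_timesteps(scheduler.state, num_inference_steps=50)
```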
@@ -263,9 +268,14 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
-        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
-        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod[:, None]
+
+        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[:, None]
+
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
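The `flatten()` plus trailing-`None` loop is the idiom that replaces the removed `match_shape` helper: it turns a per-example 1-D coefficient array into shape `(batch, 1, 1, ...)` so it broadcasts against samples of any rank. A standalone sketch of the pattern; the helper name is ours, not from the diff:

```python
import jax.numpy as jnp

def broadcast_from_left(coeffs: jnp.ndarray, target: jnp.ndarray) -> jnp.ndarray:
    """Append singleton axes until `coeffs` broadcasts against `target`."""
    coeffs = coeffs.flatten()
    while len(coeffs.shape) < len(target.shape):
        coeffs = coeffs[:, None]
    return coeffs

samples = jnp.zeros((2, 3, 8, 8))   # (batch, channels, height, width)
alphas = jnp.array([0.9, 0.5])      # one coefficient per batch element
print(broadcast_from_left(alphas, samples).shape)  # (2, 1, 1, 1)
```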
@@ -266,9 +266,14 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
-        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod[..., None]
+
         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
+
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
@@ -198,8 +198,11 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         noise: jnp.ndarray,
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
-        sigmas = self.match_shape(state.sigmas[timesteps], noise)
-        noisy_samples = original_samples + noise * sigmas
+        sigma = state.sigmas[timesteps].flatten()
+        while len(sigma.shape) < len(noise.shape):
+            sigma = sigma[..., None]
+
+        noisy_samples = original_samples + noise * sigma
         return noisy_samples
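Since the array is 1-D right after `flatten()`, `sigma[..., None]` and `sigma[:, None]` append the same trailing axis, so the two spellings used across these files behave identically. An equivalent loop-free reshape, if preferred; this is our variant, not what the diff uses:

```python
import jax.numpy as jnp

sigmas = jnp.array([0.1, 0.2, 0.3])
noise = jnp.zeros((3, 4, 16, 16))
# append noise.ndim - 1 singleton axes in a single reshape
sigma = sigmas.reshape(sigmas.shape[0], *((1,) * (noise.ndim - 1)))
assert sigma.shape == (3, 1, 1, 1)
```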
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import math

 # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+import math

 from dataclasses import dataclass
 from typing import Optional, Tuple, Union
@@ -59,7 +59,6 @@ class PNDMSchedulerState:
     # setable values
     _timesteps: jnp.ndarray
     num_inference_steps: Optional[int] = None
-    _offset: int = 0
     prk_timesteps: Optional[jnp.ndarray] = None
     plms_timesteps: Optional[jnp.ndarray] = None
     timesteps: Optional[jnp.ndarray] = None
@@ -104,6 +103,14 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         skip_prk_steps (`bool`):
             allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
             before plms steps; defaults to `False`.
+        set_alpha_to_one (`bool`, default `False`):
+            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+            otherwise it uses the value of alpha at step 0.
+        steps_offset (`int`, default `0`):
+            an offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+            stable diffusion.
     """

     @register_to_config
@@ -115,6 +122,8 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         beta_schedule: str = "linear",
         trained_betas: Optional[jnp.ndarray] = None,
         skip_prk_steps: bool = False,
+        set_alpha_to_one: bool = False,
+        steps_offset: int = 0,
     ):
         if trained_betas is not None:
             self.betas = jnp.asarray(trained_betas)
@@ -132,6 +141,8 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)

+        self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
         # For now we only support F-PNDM, i.e. the runge-kutta method
         # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
         # mainly at formula (9), (12), (13) and the Algorithm 2.
@@ -139,9 +150,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps)

-    def set_timesteps(
-        self, state: PNDMSchedulerState, num_inference_steps: int, offset: int = 0
-    ) -> PNDMSchedulerState:
+    def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -150,16 +159,15 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
                 the `FlaxPNDMScheduler` state data class instance.
             num_inference_steps (`int`):
                 the number of diffusion steps used when generating samples with a pre-trained model.
-            offset (`int`):
-                optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
         """
+        offset = self.config.steps_offset
+
         step_ratio = self.config.num_train_timesteps // num_inference_steps
         # creates integer timesteps by multiplying by ratio
         # rounding to avoid issues when num_inference_step is power of 3
-        _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]
-        _timesteps = _timesteps + offset
+        _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + offset

-        state = state.replace(num_inference_steps=num_inference_steps, _offset=offset, _timesteps=_timesteps)
+        state = state.replace(num_inference_steps=num_inference_steps, _timesteps=_timesteps)

         if self.config.skip_prk_steps:
             # for some models like stable diffusion the prk steps can/should be skipped to
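`set_timesteps` returns the state rather than mutating it because Flax scheduler states are immutable dataclasses: `state.replace(...)` builds an updated copy. A minimal sketch of that pattern, using a simplified stand-in for `PNDMSchedulerState` rather than the real class:

```python
from typing import Optional

import jax.numpy as jnp
from flax import struct

@struct.dataclass
class ToySchedulerState:
    _timesteps: jnp.ndarray
    num_inference_steps: Optional[int] = None

state = ToySchedulerState(_timesteps=jnp.arange(1000)[::-1])
new_state = state.replace(num_inference_steps=50)  # returns a copy
assert state.num_inference_steps is None           # original is unchanged
```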
@@ -254,7 +262,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
             )

         diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2
-        prev_timestep = max(timestep - diff_to_prev, state.prk_timesteps[-1])
+        prev_timestep = timestep - diff_to_prev
         timestep = state.prk_timesteps[state.counter // 4 * 4]

         if state.counter % 4 == 0:
@@ -274,7 +282,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         # cur_sample should not be `None`
         cur_sample = state.cur_sample if state.cur_sample is not None else sample

-        prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output, state=state)
+        prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
         state = state.replace(counter=state.counter + 1)

         if not return_dict:
@@ -320,7 +328,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
                 "for more information."
             )

-        prev_timestep = max(timestep - self.config.num_train_timesteps // state.num_inference_steps, 0)
+        prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps

         if state.counter != 1:
             state = state.replace(ets=state.ets.append(model_output))
@@ -344,7 +352,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
                 55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]
             )

-        prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output, state=state)
+        prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
         state = state.replace(counter=state.counter + 1)

         if not return_dict:
@@ -352,7 +360,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):

         return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)

-    def _get_prev_sample(self, sample, timestep, timestep_prev, model_output, state):
+    def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
         # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
         # this function computes x_(t−δ) using the formula of (9)
         # Note that x_t needs to be added to both sides of the equation
@@ -365,8 +373,8 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         # sample -> x_t
         # model_output -> e_θ(x_t, t)
         # prev_sample -> x_(t−δ)
-        alpha_prod_t = self.alphas_cumprod[timestep + 1 - state._offset]
-        alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - state._offset]
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
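With the offset folded into the timesteps themselves, `alphas_cumprod` can be indexed directly, and the `prev_timestep >= 0` guard covers the final step where no earlier alpha product exists. For reference, the transition this function computes, formula (9) of the PNDM paper, can be sketched as follows; this is our transcription using the diff's variable names, a sketch rather than the exact implementation:

```python
def get_prev_sample_sketch(sample, alpha_prod_t, alpha_prod_t_prev, model_output):
    # x_(t-δ) = (α_(t-δ)/α_t)^0.5 * x_t
    #           - (α_(t-δ) - α_t) * e_θ(x_t, t) / denom, with
    # denom = α_t * β_(t-δ)^0.5 + (α_t * β_t * α_(t-δ))^0.5
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** 0.5
    denom = alpha_prod_t * beta_prod_t_prev**0.5 + (alpha_prod_t * beta_prod_t * alpha_prod_t_prev) ** 0.5
    return sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / denom
```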
@@ -395,9 +403,14 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
-        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod[..., None]
+
         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
+
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
@@ -192,14 +192,17 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):

         # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
         # also equation 47 shows the analog from SDE models to ancestral sampling methods
-        drift = drift - diffusion[:, None, None, None] ** 2 * model_output
+        diffusion = diffusion.flatten()
+        while len(diffusion.shape) < len(sample.shape):
+            diffusion = diffusion[:, None]
+        drift = drift - diffusion**2 * model_output

         # equation 6: sample noise for the diffusion term of
         key = random.split(key, num=1)
         noise = random.normal(key=key, shape=sample.shape)
         prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
         # TODO is the variable diffusion the correct scaling term for the noise?
-        prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise  # add impact of diffusion field g
+        prev_sample = prev_sample_mean + diffusion * noise  # add impact of diffusion field g

         if not return_dict:
             return (prev_sample, prev_sample_mean, state)
@@ -248,8 +251,11 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         step_size = step_size * jnp.ones(sample.shape[0])

         # compute corrected sample: model_output term and noise term
-        prev_sample_mean = sample + step_size[:, None, None, None] * model_output
-        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
+        step_size = step_size.flatten()
+        while len(step_size.shape) < len(sample.shape):
+            step_size = step_size[:, None]
+        prev_sample_mean = sample + step_size * model_output
+        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise

         if not return_dict:
             return (prev_sample, state)