Unverified Commit b1182bcf authored by Kashif Rasul, committed by GitHub

[Flax] fix Flax scheduler (#564)

* remove match_shape

* ported fixes from #479 to flax

* remove unused argument

* typo

* remove warnings
parent 0424615a
@@ -96,7 +96,13 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
         clip_sample (`bool`, default `True`):
             option to clip predicted sample between -1 and 1 for numerical stability.
         set_alpha_to_one (`bool`, default `True`):
-            if alpha for final step is 1 or the final alpha of the "non-previous" one.
+            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+            otherwise it uses the value of alpha at step 0.
+        steps_offset (`int`, default `0`):
+            an offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+            stable diffusion.
     """

     @register_to_config
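Editorial note: the `set_alpha_to_one` behaviour described in the new docstring comes down to a one-line choice of the "previous" alpha product used at the final denoising step; the PNDM hunk further down adds exactly this line as `self.final_alpha_cumprod`. A minimal sketch, assuming a toy linear beta schedule:

```python
import jax.numpy as jnp

# Assumed toy schedule; the real betas come from the scheduler config.
num_train_timesteps = 1000
betas = jnp.linspace(1e-4, 2e-2, num_train_timesteps)
alphas_cumprod = jnp.cumprod(1.0 - betas, axis=0)

set_alpha_to_one = True
# True: the final step denoises toward a fully clean sample (alpha product 1.0);
# False: it reuses the alpha product of step 0 instead.
final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else alphas_cumprod[0]
```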
@@ -109,6 +115,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
         trained_betas: Optional[jnp.ndarray] = None,
         clip_sample: bool = True,
         set_alpha_to_one: bool = True,
+        steps_offset: int = 0,
     ):
         if trained_betas is not None:
             self.betas = jnp.asarray(trained_betas)
@@ -144,9 +151,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
         return variance

-    def set_timesteps(
-        self, state: DDIMSchedulerState, num_inference_steps: int, offset: int = 0
-    ) -> DDIMSchedulerState:
+    def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int) -> DDIMSchedulerState:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

@@ -155,9 +160,9 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
                 the `FlaxDDIMScheduler` state data class instance.
             num_inference_steps (`int`):
                 the number of diffusion steps used when generating samples with a pre-trained model.
-            offset (`int`):
-                optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
         """
+        offset = self.config.steps_offset
         step_ratio = self.config.num_train_timesteps // num_inference_steps
         # creates integer timesteps by multiplying by ratio
         # casting to int to avoid issues when num_inference_step is power of 3
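Editorial note: to make the `steps_offset` move concrete, with the usual stable-diffusion settings (1000 training steps, 50 inference steps, `steps_offset=1`) the timestep grid is shifted up by one so the chain ends at step 1 rather than step 0. A hedged sketch of the arithmetic, mirroring the construction visible in the PNDM hunk further down (the DDIM line itself is outside this hunk):

```python
import jax.numpy as jnp

num_train_timesteps = 1000  # config value, assumed for illustration
num_inference_steps = 50
offset = 1                  # steps_offset, as used by stable diffusion

step_ratio = num_train_timesteps // num_inference_steps  # 20
timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + offset
print(timesteps[:3], timesteps[-1])  # [ 1. 21. 41.] ... 981.0
```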
@@ -263,9 +268,14 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
-        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod[:, None]
+
         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[:, None]
+
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
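Editorial note: the removed `match_shape` helper is replaced throughout this commit by the same inline pattern: flatten to a 1-D per-sample vector, then append singleton axes until the rank matches the samples, letting ordinary broadcasting do the rest. The `while` loop runs over static ranks at trace time, so it stays `jit`-friendly. A self-contained sketch with assumed toy shapes:

```python
import jax.numpy as jnp

original_samples = jnp.zeros((4, 64, 64, 3))        # assumed NHWC batch
alphas_cumprod = jnp.linspace(0.999, 0.01, 1000)    # assumed schedule
timesteps = jnp.array([10, 200, 500, 999])          # one timestep per sample

sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5  # shape (4,)
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
    sqrt_alpha_prod = sqrt_alpha_prod[:, None]      # (4,) -> (4, 1, 1, 1)

assert sqrt_alpha_prod.shape == (4, 1, 1, 1)
noisy = sqrt_alpha_prod * original_samples          # broadcasts cleanly
```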
...
@@ -266,9 +266,14 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
-        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod[..., None]
+
         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
+
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
...
@@ -198,8 +198,11 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         noise: jnp.ndarray,
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
-        sigmas = self.match_shape(state.sigmas[timesteps], noise)
-        noisy_samples = original_samples + noise * sigmas
+        sigma = state.sigmas[timesteps].flatten()
+        while len(sigma.shape) < len(noise.shape):
+            sigma = sigma[..., None]
+
+        noisy_samples = original_samples + noise * sigma
         return noisy_samples
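Editorial note: the hunks are not uniform about the indexing idiom; DDIM and the SDE-VE scheduler use `[:, None]` while DDPM, LMS, and PNDM use `[..., None]`. For the 1-D arrays produced by `flatten()` the two are equivalent, as this quick check shows:

```python
import jax.numpy as jnp

v = jnp.array([14.6, 0.3])  # e.g. per-sample sigmas after flatten()
assert v[:, None].shape == v[..., None].shape == (2, 1)
```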
...
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math

 # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
-import math
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union
@@ -59,7 +59,6 @@ class PNDMSchedulerState:
     # setable values
     _timesteps: jnp.ndarray
     num_inference_steps: Optional[int] = None
-    _offset: int = 0
     prk_timesteps: Optional[jnp.ndarray] = None
     plms_timesteps: Optional[jnp.ndarray] = None
     timesteps: Optional[jnp.ndarray] = None
@@ -104,6 +103,14 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         skip_prk_steps (`bool`):
             allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
             before plms steps; defaults to `False`.
+        set_alpha_to_one (`bool`, default `False`):
+            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+            otherwise it uses the value of alpha at step 0.
+        steps_offset (`int`, default `0`):
+            an offset added to the inference steps. You can use a combination of `offset=1` and
+            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+            stable diffusion.
     """

     @register_to_config
@@ -115,6 +122,8 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         beta_schedule: str = "linear",
         trained_betas: Optional[jnp.ndarray] = None,
         skip_prk_steps: bool = False,
+        set_alpha_to_one: bool = False,
+        steps_offset: int = 0,
     ):
         if trained_betas is not None:
             self.betas = jnp.asarray(trained_betas)
@@ -132,6 +141,8 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)

+        self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
         # For now we only support F-PNDM, i.e. the runge-kutta method
         # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
         # mainly at formula (9), (12), (13) and the Algorithm 2.
@@ -139,9 +150,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps)

-    def set_timesteps(
-        self, state: PNDMSchedulerState, num_inference_steps: int, offset: int = 0
-    ) -> PNDMSchedulerState:
+    def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -150,16 +159,15 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
                 the `FlaxPNDMScheduler` state data class instance.
             num_inference_steps (`int`):
                 the number of diffusion steps used when generating samples with a pre-trained model.
-            offset (`int`):
-                optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
         """
+        offset = self.config.steps_offset
         step_ratio = self.config.num_train_timesteps // num_inference_steps
         # creates integer timesteps by multiplying by ratio
         # rounding to avoid issues when num_inference_step is power of 3
-        _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]
-        _timesteps = _timesteps + offset
+        _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + offset

-        state = state.replace(num_inference_steps=num_inference_steps, _offset=offset, _timesteps=_timesteps)
+        state = state.replace(num_inference_steps=num_inference_steps, _timesteps=_timesteps)

         if self.config.skip_prk_steps:
             # for some models like stable diffusion the prk steps can/should be skipped to
@@ -254,7 +262,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         )

         diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2
-        prev_timestep = max(timestep - diff_to_prev, state.prk_timesteps[-1])
+        prev_timestep = timestep - diff_to_prev
         timestep = state.prk_timesteps[state.counter // 4 * 4]

         if state.counter % 4 == 0:
@@ -274,7 +282,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         # cur_sample should not be `None`
         cur_sample = state.cur_sample if state.cur_sample is not None else sample

-        prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output, state=state)
+        prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
         state = state.replace(counter=state.counter + 1)

         if not return_dict:
@@ -320,7 +328,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
                 "for more information."
             )

-        prev_timestep = max(timestep - self.config.num_train_timesteps // state.num_inference_steps, 0)
+        prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps

         if state.counter != 1:
             state = state.replace(ets=state.ets.append(model_output))
@@ -344,7 +352,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
                 55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]
             )

-        prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output, state=state)
+        prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
         state = state.replace(counter=state.counter + 1)

         if not return_dict:
@@ -352,7 +360,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)

-    def _get_prev_sample(self, sample, timestep, timestep_prev, model_output, state):
+    def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
         # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
         # this function computes x_(t−δ) using the formula of (9)
         # Note that x_t needs to be added to both sides of the equation
@@ -365,8 +373,8 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         # sample -> x_t
         # model_output -> e_θ(x_t, t)
         # prev_sample -> x_(t−δ)
-        alpha_prod_t = self.alphas_cumprod[timestep + 1 - state._offset]
-        alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - state._offset]
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
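Editorial note: this hunk is what lets the `max(..., 0)` clamps above be dropped. `prev_timestep` may now go negative at the very last step, and the `final_alpha_cumprod` fallback handles it instead of index bookkeeping via `_offset`. A small worked example, with the stable-diffusion-style settings assumed:

```python
import jax.numpy as jnp

alphas_cumprod = jnp.linspace(0.999, 0.01, 1000)  # assumed schedule
final_alpha_cumprod = jnp.array(1.0)              # set_alpha_to_one=True

step = 1000 // 50                                 # 20
for timestep in (20, 0):
    prev_timestep = timestep - step               # 0, then -20 at the last step
    alpha_prod_t_prev = (
        alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alpha_cumprod
    )
```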
@@ -395,9 +403,14 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
-        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod[..., None]
+
         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
+
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
...
@@ -192,14 +192,17 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
         # also equation 47 shows the analog from SDE models to ancestral sampling methods
-        drift = drift - diffusion[:, None, None, None] ** 2 * model_output
+        diffusion = diffusion.flatten()
+        while len(diffusion.shape) < len(sample.shape):
+            diffusion = diffusion[:, None]
+        drift = drift - diffusion**2 * model_output

         # equation 6: sample noise for the diffusion term of
         key = random.split(key, num=1)
         noise = random.normal(key=key, shape=sample.shape)
         prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
         # TODO is the variable diffusion the correct scaling term for the noise?
-        prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise  # add impact of diffusion field g
+        prev_sample = prev_sample_mean + diffusion * noise  # add impact of diffusion field g

         if not return_dict:
             return (prev_sample, prev_sample_mean, state)
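Editorial note: this hunk is the predictor half of the predictor-corrector sampler, one Euler-Maruyama step of the reverse SDE, with the per-sample diffusion coefficient g(t) broadcast the same way as elsewhere in this commit (any hard-coded rank-4 `[:, None, None, None]` becomes rank-agnostic). A hedged, self-contained sketch with assumed toy shapes:

```python
import jax.numpy as jnp
from jax import random

sample = jnp.zeros((2, 32, 32, 3))     # assumed batch of images
model_output = jnp.zeros_like(sample)  # score estimate grad_x log p_t(x)
drift = jnp.zeros_like(sample)         # f(x, t) term, assumed precomputed
diffusion = jnp.array([1.5, 2.0])      # per-sample g(t)

diffusion = diffusion.flatten()
while len(diffusion.shape) < len(sample.shape):
    diffusion = diffusion[:, None]     # (2,) -> (2, 1, 1, 1)

drift = drift - diffusion**2 * model_output
noise = random.normal(random.PRNGKey(0), shape=sample.shape)
prev_sample_mean = sample - drift      # minus: dt is a small negative step
prev_sample = prev_sample_mean + diffusion * noise
```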
@@ -248,8 +251,11 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         step_size = step_size * jnp.ones(sample.shape[0])

         # compute corrected sample: model_output term and noise term
-        prev_sample_mean = sample + step_size[:, None, None, None] * model_output
-        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
+        step_size = step_size.flatten()
+        while len(step_size.shape) < len(sample.shape):
+            step_size = step_size[:, None]
+        prev_sample_mean = sample + step_size * model_output
+        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise

         if not return_dict:
             return (prev_sample, state)
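Editorial note: the corrector half is annealed Langevin dynamics, x <- x + eps * score + sqrt(2 * eps) * z, and the hunk simply generalizes the broadcast of the per-sample step size eps. A toy sketch under the same assumptions as above:

```python
import jax.numpy as jnp

sample = jnp.zeros((2, 32, 32, 3))     # assumed batch
model_output = jnp.zeros_like(sample)  # score estimate
noise = jnp.ones_like(sample)          # z, assumed pre-sampled
step_size = 0.01 * jnp.ones(sample.shape[0])

step_size = step_size.flatten()
while len(step_size.shape) < len(sample.shape):
    step_size = step_size[:, None]     # (2,) -> (2, 1, 1, 1)

prev_sample_mean = sample + step_size * model_output
prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise
```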
...