Unverified Commit 155d272c authored by bachr's avatar bachr Committed by GitHub
Browse files

Update FlaxLMSDiscreteScheduler (#1474)

- Add the missing `scale_model_input` method to `FlaxLMSDiscreteScheduler`
- Use `jnp.append` for appending to `state.derivatives`
- Use `jnp.delete` to pop from `state.derivatives`
parent 2b30b109
......@@ -102,6 +102,28 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5,
)
def scale_model_input(
    self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int
) -> jnp.ndarray:
    """
    Divides the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.

    Args:
        state (`LMSDiscreteSchedulerState`):
            the `FlaxLMSDiscreteScheduler` state data class instance; its `timesteps` and `sigmas` arrays are
            read here.
        sample (`jnp.ndarray`):
            current instance of sample being created by diffusion process.
        timestep (`int`):
            current discrete timestep in the diffusion chain.

    Returns:
        `jnp.ndarray`: scaled input sample, with the same shape as `sample`.
    """
    # Look the timestep up in the schedule held by `state` (the argument of this
    # method). `size=1` requests exactly one index, giving a fixed output shape.
    (step_index,) = jnp.where(state.timesteps == timestep, size=1)
    # `jnp.where` returned a length-1 index array; reduce it to a scalar so that
    # `sigma` is a scalar and the division below cannot change the sample's shape.
    step_index = step_index[0]

    sigma = state.sigmas[step_index]
    return sample / ((sigma**2 + 1) ** 0.5)
def get_lms_coefficient(self, state, order, t, current_order):
"""
Compute a linear multistep coefficient.
......@@ -186,9 +208,9 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma
state = state.replace(derivatives=state.derivatives.append(derivative))
state = state.replace(derivatives=jnp.append(state.derivatives, derivative))
if len(state.derivatives) > order:
state = state.replace(derivatives=state.derivatives.pop(0))
state = state.replace(derivatives=jnp.delete(state.derivatives, 0))
# 3. Compute linear multistep coefficients
order = min(timestep + 1, order)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment