Commit 7222a8ea authored by Patrick von Platen's avatar Patrick von Platen
Browse files

make style

parent 155d272c
......@@ -102,9 +102,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5,
)
def scale_model_input(
self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int
) -> jnp.ndarray:
def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray:
"""
Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.
......@@ -119,7 +117,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
Returns:
`jnp.ndarray`: scaled input sample
"""
step_index, = jnp.where(scheduler_state.timesteps == timestep, size=1)
(step_index,) = jnp.where(scheduler_state.timesteps == timestep, size=1)
sigma = scheduler_state.sigmas[step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment