Commit 7222a8ea authored by Patrick von Platen's avatar Patrick von Platen
Browse files

make style

parent 155d272c
...@@ -102,9 +102,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -102,9 +102,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5, sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5,
) )
def scale_model_input( def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray:
self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int
) -> jnp.ndarray:
""" """
Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.
...@@ -119,7 +117,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -119,7 +117,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
Returns: Returns:
`jnp.ndarray`: scaled input sample `jnp.ndarray`: scaled input sample
""" """
step_index, = jnp.where(scheduler_state.timesteps == timestep, size=1) (step_index,) = jnp.where(scheduler_state.timesteps == timestep, size=1)
sigma = scheduler_state.sigmas[step_index] sigma = scheduler_state.sigmas[step_index]
sample = sample / ((sigma**2 + 1) ** 0.5) sample = sample / ((sigma**2 + 1) ** 0.5)
return sample return sample
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment