Unverified Commit d9e7857a authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

timestep_spacing for FlaxDPMSolverMultistepScheduler (#5189)

* timestep_spacing for FlaxDPMSolverMultistepScheduler

* Style
parent fd1c54ab
...@@ -135,6 +135,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -135,6 +135,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
lower_order_final (`bool`, default `True`): lower_order_final (`bool`, default `True`):
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
the `dtype` used for params and computation. the `dtype` used for params and computation.
""" """
...@@ -163,6 +166,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -163,6 +166,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
algorithm_type: str = "dpmsolver++", algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint", solver_type: str = "midpoint",
lower_order_final: bool = True, lower_order_final: bool = True,
timestep_spacing: str = "linspace",
dtype: jnp.dtype = jnp.float32, dtype: jnp.dtype = jnp.float32,
): ):
self.dtype = dtype self.dtype = dtype
...@@ -210,12 +214,29 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -210,12 +214,29 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
shape (`Tuple`): shape (`Tuple`):
the shape of the samples to be generated. the shape of the samples to be generated.
""" """
last_timestep = self.config.num_train_timesteps
timesteps = ( if self.config.timestep_spacing == "linspace":
jnp.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) timesteps = (
.round()[::-1][:-1] jnp.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].astype(jnp.int32)
.astype(jnp.int32) )
) elif self.config.timestep_spacing == "leading":
step_ratio = last_timestep // (num_inference_steps + 1)
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (
(jnp.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(jnp.int32)
)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = jnp.arange(last_timestep, 0, -step_ratio).round().copy().astype(jnp.int32)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
# initial running values # initial running values
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment