Unverified Commit b9e921fe authored by Kashif Rasul's avatar Kashif Rasul Committed by GitHub
Browse files

added initial v-pred support to DPM-solver (#1421)

* added initial v-pred support to DPM-solver

* fix sign

* added v_prediction to flax

* fixed typo
parent 76845183
...@@ -88,8 +88,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -88,8 +88,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling. sampling, and `solver_order=3` for unconditional sampling.
prediction_type (`str`, default `epsilon`): prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`. indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`,
`v-prediction` is not supported for this scheduler. or `v-prediction`.
thresholding (`bool`, default `False`): thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
...@@ -212,7 +212,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -212,7 +212,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
""" """
Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
DPM-Solver is designed to discretize an integral of the noise prediciton model, and DPM-Solver++ is designed to DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
discretize an integral of the data prediction model. So we need to first convert the model output to the discretize an integral of the data prediction model. So we need to first convert the model output to the
corresponding type to match the algorithm. corresponding type to match the algorithm.
...@@ -235,10 +235,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -235,10 +235,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
x0_pred = (sample - sigma_t * model_output) / alpha_t x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample": elif self.config.prediction_type == "sample":
x0_pred = model_output x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = alpha_t * sample - sigma_t * model_output
else: else:
raise ValueError( raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` " f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" for the DPMSolverMultistepScheduler." " `v_prediction` for the DPMSolverMultistepScheduler."
) )
if self.config.thresholding: if self.config.thresholding:
...@@ -260,10 +263,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -260,10 +263,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon return epsilon
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = alpha_t * model_output + sigma_t * sample
return epsilon
else: else:
raise ValueError( raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` " f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" for the DPMSolverMultistepScheduler." " `v_prediction` for the DPMSolverMultistepScheduler."
) )
def dpm_solver_first_order_update( def dpm_solver_first_order_update(
......
...@@ -120,8 +120,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -120,8 +120,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling. sampling, and `solver_order=3` for unconditional sampling.
prediction_type (`str`, default `epsilon`): prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`. indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`,
`v-prediction` is not supported for this scheduler. or `v-prediction`.
thresholding (`bool`, default `False`): thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
...@@ -252,7 +252,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -252,7 +252,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
""" """
Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
DPM-Solver is designed to discretize an integral of the noise prediciton model, and DPM-Solver++ is designed to DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
discretize an integral of the data prediction model. So we need to first convert the model output to the discretize an integral of the data prediction model. So we need to first convert the model output to the
corresponding type to match the algorithm. corresponding type to match the algorithm.
...@@ -275,10 +275,13 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -275,10 +275,13 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
x0_pred = (sample - sigma_t * model_output) / alpha_t x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample": elif self.config.prediction_type == "sample":
x0_pred = model_output x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = alpha_t * sample - sigma_t * model_output
else: else:
raise ValueError( raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` " f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
" for the FlaxDPMSolverMultistepScheduler." " or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
) )
if self.config.thresholding: if self.config.thresholding:
...@@ -299,10 +302,14 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -299,10 +302,14 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon return epsilon
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = alpha_t * model_output + sigma_t * sample
return epsilon
else: else:
raise ValueError( raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` " f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
" for the FlaxDPMSolverMultistepScheduler." " or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
) )
def dpm_solver_first_order_update( def dpm_solver_first_order_update(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment