Unverified Commit 91db8189 authored by Jonathan Whitaker's avatar Jonathan Whitaker Committed by GitHub
Browse files

Adding pred_original_sample to SchedulerOutput for some samplers (#614)

* Adding pred_original_sample to SchedulerOutput of DDPMScheduler, DDIMScheduler, LMSDiscreteScheduler, KarrasVeScheduler step methods so we can access the predicted denoised outputs

* Gave DDPMScheduler, DDIMScheduler and LMSDiscreteScheduler their own output dataclasses so the default SchedulerOutput in scheduling_utils does not need pred_original_sample as an optional extra

* Reordered library imports to follow standard

* didn't get the import order quite right, apparently

* Forgot to change name of LMSDiscreteSchedulerOutput

* Aha, needed some extra libs for make style to fully work
parent f149d037
...@@ -17,13 +17,33 @@ ...@@ -17,13 +17,33 @@
import math import math
import warnings import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@dataclass
class DDIMSchedulerOutput(BaseOutput):
    """
    Output class for the DDIM scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    # x_{t-1}: the sample to feed back into the model on the next denoising step.
    prev_sample: torch.FloatTensor
    # x_0 estimate; optional so tuple-style callers that only need prev_sample are unaffected.
    pred_original_sample: Optional[torch.FloatTensor] = None
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
...@@ -179,7 +199,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -179,7 +199,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
use_clipped_model_output: bool = False, use_clipped_model_output: bool = False,
generator=None, generator=None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: ) -> Union[DDIMSchedulerOutput, Tuple]:
""" """
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise). process from the learned model outputs (most often the predicted noise).
...@@ -192,11 +212,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -192,11 +212,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
eta (`float`): weight of noise for added noise in diffusion step. eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): TODO use_clipped_model_output (`bool`): TODO
generator: random number generator. generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
Returns: Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor. returning a tuple, the first element is the sample tensor.
""" """
...@@ -261,7 +281,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -261,7 +281,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
if not return_dict: if not return_dict:
return (prev_sample,) return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample) return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
def add_noise( def add_noise(
self, self,
......
...@@ -15,13 +15,33 @@ ...@@ -15,13 +15,33 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@dataclass
class DDPMSchedulerOutput(BaseOutput):
    """
    Output class for the DDPM scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    # x_{t-1}: the sample to feed back into the model on the next denoising step.
    prev_sample: torch.FloatTensor
    # x_0 estimate; optional so tuple-style callers that only need prev_sample are unaffected.
    pred_original_sample: Optional[torch.FloatTensor] = None
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
...@@ -177,7 +197,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -177,7 +197,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
predict_epsilon=True, predict_epsilon=True,
generator=None, generator=None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: ) -> Union[DDPMSchedulerOutput, Tuple]:
""" """
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise). process from the learned model outputs (most often the predicted noise).
...@@ -190,11 +210,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -190,11 +210,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
predict_epsilon (`bool`): predict_epsilon (`bool`):
optional flag to use when model predicts the samples directly instead of the noise, epsilon. optional flag to use when model predicts the samples directly instead of the noise, epsilon.
generator: random number generator. generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
Returns: Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor. returning a tuple, the first element is the sample tensor.
""" """
...@@ -242,7 +262,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -242,7 +262,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
if not return_dict: if not return_dict:
return (pred_prev_sample,) return (pred_prev_sample,)
return SchedulerOutput(prev_sample=pred_prev_sample) return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
def add_noise( def add_noise(
self, self,
......
...@@ -35,10 +35,14 @@ class KarrasVeOutput(BaseOutput): ...@@ -35,10 +35,14 @@ class KarrasVeOutput(BaseOutput):
denoising loop. denoising loop.
derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Derivative of predicted original image sample (x_0). Derivative of predicted original image sample (x_0).
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
""" """
prev_sample: torch.FloatTensor prev_sample: torch.FloatTensor
derivative: torch.FloatTensor derivative: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None
class KarrasVeScheduler(SchedulerMixin, ConfigMixin): class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
...@@ -153,7 +157,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -153,7 +157,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
sigma_hat (`float`): TODO sigma_hat (`float`): TODO
sigma_prev (`float`): TODO sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check). KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check).
Returns: Returns:
...@@ -170,7 +174,9 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -170,7 +174,9 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
if not return_dict: if not return_dict:
return (sample_prev, derivative) return (sample_prev, derivative)
return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative) return KarrasVeOutput(
prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
)
def step_correct( def step_correct(
self, self,
...@@ -192,7 +198,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -192,7 +198,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
derivative (`torch.FloatTensor` or `np.ndarray`): TODO derivative (`torch.FloatTensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
Returns: Returns:
prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
...@@ -205,7 +211,9 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -205,7 +211,9 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
if not return_dict: if not return_dict:
return (sample_prev, derivative) return (sample_prev, derivative)
return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative) return KarrasVeOutput(
prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
)
def add_noise(self, original_samples, noise, timesteps): def add_noise(self, original_samples, noise, timesteps):
raise NotImplementedError() raise NotImplementedError()
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -20,7 +21,26 @@ import torch ...@@ -20,7 +21,26 @@ import torch
from scipy import integrate from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@dataclass
class LMSDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the LMS discrete scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    # x_{t-1}: the sample to feed back into the model on the next denoising step.
    prev_sample: torch.FloatTensor
    # x_0 estimate; optional so tuple-style callers that only need prev_sample are unaffected.
    pred_original_sample: Optional[torch.FloatTensor] = None
class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
...@@ -133,7 +153,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -133,7 +153,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
sample: Union[torch.FloatTensor, np.ndarray], sample: Union[torch.FloatTensor, np.ndarray],
order: int = 4, order: int = 4,
return_dict: bool = True, return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]: ) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
""" """
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise). process from the learned model outputs (most often the predicted noise).
...@@ -144,12 +164,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -144,12 +164,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
sample (`torch.FloatTensor` or `np.ndarray`): sample (`torch.FloatTensor` or `np.ndarray`):
current instance of sample being created by diffusion process. current instance of sample being created by diffusion process.
order: coefficient for multi-step inference. order: coefficient for multi-step inference.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class
Returns: Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
returning a tuple, the first element is the sample tensor. When returning a tuple, the first element is the sample tensor.
""" """
sigma = self.sigmas[timestep] sigma = self.sigmas[timestep]
...@@ -175,7 +195,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -175,7 +195,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
if not return_dict: if not return_dict:
return (prev_sample,) return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample) return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
def add_noise( def add_noise(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment