Unverified Commit 91db8189 authored by Jonathan Whitaker's avatar Jonathan Whitaker Committed by GitHub
Browse files

Adding pred_original_sample to SchedulerOutput for some samplers (#614)

* Adding pred_original_sample to SchedulerOutput of DDPMScheduler, DDIMScheduler, LMSDiscreteScheduler, KarrasVeScheduler step methods so we can access the predicted denoised outputs

* Gave DDPMScheduler, DDIMScheduler and LMSDiscreteScheduler their own output dataclasses so the default SchedulerOutput in scheduling_utils does not need pred_original_sample as an optional extra

* Reordered library imports to follow standard

* didnt get import order quite right apparently

* Forgot to change name of LMSDiscreteSchedulerOutput

* Aha, needed some extra libs for make style to fully work
parent f149d037
......@@ -17,13 +17,33 @@
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@dataclass
class DDIMSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
......@@ -179,7 +199,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
use_clipped_model_output: bool = False,
generator=None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
) -> Union[DDIMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
......@@ -192,11 +212,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): TODO
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
[`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
......@@ -261,7 +281,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
def add_noise(
self,
......
......@@ -15,13 +15,33 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@dataclass
class DDPMSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
......@@ -177,7 +197,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
predict_epsilon=True,
generator=None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
) -> Union[DDPMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
......@@ -190,11 +210,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
predict_epsilon (`bool`):
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
[`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
......@@ -242,7 +262,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
if not return_dict:
return (pred_prev_sample,)
return SchedulerOutput(prev_sample=pred_prev_sample)
return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
def add_noise(
self,
......
......@@ -35,10 +35,14 @@ class KarrasVeOutput(BaseOutput):
denoising loop.
derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Derivative of predicted original image sample (x_0).
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
derivative: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None
class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
......@@ -153,7 +157,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
sigma_hat (`float`): TODO
sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check).
Returns:
......@@ -170,7 +174,9 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
if not return_dict:
return (sample_prev, derivative)
return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
return KarrasVeOutput(
prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
)
def step_correct(
self,
......@@ -192,7 +198,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
derivative (`torch.FloatTensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
Returns:
prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
......@@ -205,7 +211,9 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
if not return_dict:
return (sample_prev, derivative)
return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
return KarrasVeOutput(
prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
)
def add_noise(self, original_samples, noise, timesteps):
raise NotImplementedError()
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
......@@ -20,7 +21,26 @@ import torch
from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@dataclass
class LMSDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None
class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
......@@ -133,7 +153,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
sample: Union[torch.FloatTensor, np.ndarray],
order: int = 4,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
......@@ -144,12 +164,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
sample (`torch.FloatTensor` or `np.ndarray`):
current instance of sample being created by diffusion process.
order: coefficient for multi-step inference.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
[`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
"""
sigma = self.sigmas[timestep]
......@@ -175,7 +195,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
def add_noise(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment