Unverified commit 5f0df177, authored by hlky and committed by GitHub

Refactor SchedulerOutput and add pred_original_sample in `DPMSolverSDE`, `Heun`, `KDPM2Ancestral` and `KDPM2` (#9650)
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 957e5cab
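
The change replaces the generic `SchedulerOutput` with per-scheduler output classes that additionally expose `pred_original_sample`. As orientation before the diff, here is a minimal usage sketch (not part of the commit), assuming a diffusers build that includes this change; the UNet call is replaced by a random tensor so the snippet stays self-contained.

```python
import torch
from diffusers import HeunDiscreteScheduler

# Build the scheduler from its default config; no pretrained weights are needed
# just to exercise the step() API touched by this commit.
scheduler = HeunDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=10)

# Start from noise scaled by the scheduler's initial sigma.
sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma

for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)
    # A real pipeline would run its UNet here; a random tensor stands in for the
    # predicted noise so the example has no model dependency.
    model_output = torch.randn_like(model_input)

    step_output = scheduler.step(model_output, t, sample)
    sample = step_output.prev_sample
    # New field: the denoised estimate is exposed next to prev_sample.
    x0_estimate = step_output.pred_original_sample
```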
--- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import math
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -20,14 +21,33 @@ import torch
 import torchsde
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import is_scipy_available
-from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+from ..utils import BaseOutput, is_scipy_available
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
 
 
 if is_scipy_available():
     import scipy.stats
 
 
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DPMSolverSDE
+class DPMSolverSDESchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+
+    Args:
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
+
+
 class BatchedBrownianTree:
     """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
@@ -510,7 +530,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         sample: Union[torch.Tensor, np.ndarray],
         return_dict: bool = True,
         s_noise: float = 1.0,
-    ) -> Union[SchedulerOutput, Tuple]:
+    ) -> Union[DPMSolverSDESchedulerOutput, Tuple]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -522,15 +542,16 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor` or `np.ndarray`):
                 A current instance of a sample created by the diffusion process.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
+            return_dict (`bool`):
+                Whether or not to return a [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or
+                tuple.
             s_noise (`float`, *optional*, defaults to 1.0):
                 Scaling factor for noise added to the sample.
 
         Returns:
-            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
-                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
-                tuple is returned where the first element is the sample tensor.
+            [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or `tuple`:
+                If return_dict is `True`, [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] is
+                returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
         if self.step_index is None:
             self._init_step_index(timestep)
@@ -610,9 +631,12 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         self._step_index += 1
 
         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )
 
-        return SchedulerOutput(prev_sample=prev_sample)
+        return DPMSolverSDESchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
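
One behavioral consequence of the hunk above: with `return_dict=False`, `step()` now returns `(prev_sample, pred_original_sample)` instead of the old 1-tuple. A small sketch (dummy model output again; `DPMSolverSDEScheduler` additionally requires `torchsde`):

```python
import torch
from diffusers import DPMSolverSDEScheduler  # needs the optional torchsde dependency

scheduler = DPMSolverSDEScheduler()
scheduler.set_timesteps(num_inference_steps=10)

sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
t = scheduler.timesteps[0]
model_output = torch.randn_like(sample)  # stand-in for a UNet prediction

# Previously this call returned (prev_sample,); it now also carries the
# denoised estimate as the second element.
prev_sample, pred_original_sample = scheduler.step(model_output, t, sample, return_dict=False)
```

Callers that index the tuple, e.g. `scheduler.step(..., return_dict=False)[0]`, keep working because `prev_sample` is still the first element.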
--- a/src/diffusers/schedulers/scheduling_heun_discrete.py
+++ b/src/diffusers/schedulers/scheduling_heun_discrete.py
@@ -13,20 +13,40 @@
 # limitations under the License.
 
 import math
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import is_scipy_available
-from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+from ..utils import BaseOutput, is_scipy_available
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
 
 
 if is_scipy_available():
     import scipy.stats
 
 
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->HeunDiscrete
+class HeunDiscreteSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+
+    Args:
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
+
+
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
@@ -455,7 +475,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         timestep: Union[float, torch.Tensor],
         sample: Union[torch.Tensor, np.ndarray],
         return_dict: bool = True,
-    ) -> Union[SchedulerOutput, Tuple]:
+    ) -> Union[HeunDiscreteSchedulerOutput, Tuple]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -468,12 +488,13 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             return_dict (`bool`):
-                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
+                Whether or not to return a [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] or
+                tuple.
 
         Returns:
-            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
-                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
-                tuple is returned where the first element is the sample tensor.
+            [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`:
+                If return_dict is `True`, [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] is
+                returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
         if self.step_index is None:
             self._init_step_index(timestep)
@@ -544,9 +565,12 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self._step_index += 1
 
         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )
 
-        return SchedulerOutput(prev_sample=prev_sample)
+        return HeunDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
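
The new output classes are `@dataclass`es built on `BaseOutput`, so the result of `step()` can be read by attribute, by key, or by position. A sketch using a hypothetical stand-in class (`ExampleSchedulerOutput` is not part of the commit) that mirrors `HeunDiscreteSchedulerOutput`:

```python
from dataclasses import dataclass
from typing import Optional

import torch
from diffusers.utils import BaseOutput


@dataclass
class ExampleSchedulerOutput(BaseOutput):
    # Same two fields as the scheduler outputs introduced in this commit.
    prev_sample: torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None


out = ExampleSchedulerOutput(prev_sample=torch.zeros(1), pred_original_sample=torch.ones(1))

# BaseOutput behaves like a dataclass, a dict, and a tuple at the same time.
assert out.prev_sample is out["prev_sample"]
assert out.prev_sample is out[0]
assert out.pred_original_sample is out[1]
```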
--- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
+++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
@@ -13,21 +13,41 @@
 # limitations under the License.
 
 import math
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import is_scipy_available
+from ..utils import BaseOutput, is_scipy_available
 from ..utils.torch_utils import randn_tensor
-from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
 
 
 if is_scipy_available():
     import scipy.stats
 
 
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->KDPM2AncestralDiscrete
+class KDPM2AncestralDiscreteSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+
+    Args:
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
+
+
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
@@ -459,7 +479,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sample: Union[torch.Tensor, np.ndarray],
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
-    ) -> Union[SchedulerOutput, Tuple]:
+    ) -> Union[KDPM2AncestralDiscreteSchedulerOutput, Tuple]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -474,12 +494,14 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
             generator (`torch.Generator`, *optional*):
                 A random number generator.
             return_dict (`bool`):
-                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
+                Whether or not to return a
+                [`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] or tuple.
 
         Returns:
-            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
-                If return_dict is `True`, [`~schedulers.scheduling_ddim.SchedulerOutput`] is returned, otherwise a
-                tuple is returned where the first element is the sample tensor.
+            [`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] or `tuple`:
+                If return_dict is `True`,
+                [`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] is
+                returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
         if self.step_index is None:
             self._init_step_index(timestep)
@@ -548,9 +570,14 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self._step_index += 1
 
         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )
 
-        return SchedulerOutput(prev_sample=prev_sample)
+        return KDPM2AncestralDiscreteSchedulerOutput(
+            prev_sample=prev_sample, pred_original_sample=pred_original_sample
+        )
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
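
`KDPM2AncestralDiscreteScheduler.step` also takes a `generator`, which makes the ancestral noise drawn inside the step reproducible. A short sketch (dummy model output again) reading both fields of the new output:

```python
import torch
from diffusers import KDPM2AncestralDiscreteScheduler

scheduler = KDPM2AncestralDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=10)

sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
t = scheduler.timesteps[0]
model_output = torch.randn_like(sample)  # stand-in for a UNet prediction

# Seeded generator so the noise added by the ancestral step is repeatable.
generator = torch.Generator().manual_seed(0)
out = scheduler.step(model_output, t, sample, generator=generator)

next_sample = out.prev_sample
x0_estimate = out.pred_original_sample
```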
--- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
+++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
@@ -13,20 +13,40 @@
 # limitations under the License.
 
 import math
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import is_scipy_available
-from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+from ..utils import BaseOutput, is_scipy_available
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
 
 
 if is_scipy_available():
     import scipy.stats
 
 
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->KDPM2Discrete
+class KDPM2DiscreteSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+
+    Args:
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
+
+
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
@@ -443,7 +463,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         timestep: Union[float, torch.Tensor],
         sample: Union[torch.Tensor, np.ndarray],
         return_dict: bool = True,
-    ) -> Union[SchedulerOutput, Tuple]:
+    ) -> Union[KDPM2DiscreteSchedulerOutput, Tuple]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -456,12 +476,13 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             return_dict (`bool`):
-                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
+                Whether or not to return a [`~schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteSchedulerOutput`] or
+                tuple.
 
         Returns:
-            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
-                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
-                tuple is returned where the first element is the sample tensor.
+            [`~schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteSchedulerOutput`] or `tuple`:
+                If return_dict is `True`, [`~schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteSchedulerOutput`] is
+                returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
         if self.step_index is None:
             self._init_step_index(timestep)
@@ -523,9 +544,12 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         prev_sample = sample + derivative * dt
 
         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )
 
-        return SchedulerOutput(prev_sample=prev_sample)
+        return KDPM2DiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
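
Finally, a quick smoke-test sketch (not part of the commit) over the schedulers touched here; `DPMSolverSDEScheduler` is skipped only because it needs `torchsde`. It assumes, as the docstrings above state, that `step()` now populates `pred_original_sample`.

```python
import torch
from diffusers import (
    HeunDiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    KDPM2DiscreteScheduler,
)

for cls in (HeunDiscreteScheduler, KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler):
    scheduler = cls()
    scheduler.set_timesteps(num_inference_steps=10)

    sample = torch.randn(1, 4, 32, 32) * scheduler.init_noise_sigma
    t = scheduler.timesteps[0]
    out = scheduler.step(torch.randn_like(sample), t, sample)

    # Shared interface after this refactor: same two fields on every output.
    assert out.prev_sample.shape == sample.shape
    assert out.pred_original_sample is not None
    print(f"{cls.__name__}: ok")
```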