Unverified commit 5f0df177, authored by hlky, committed by GitHub
Refactor SchedulerOutput and add pred_original_sample in `DPMSolverSDE`, `Heun`, `KDPM2Ancestral` and `KDPM2` (#9650)

Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 957e5cab
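
As a quick illustration of what this change exposes, here is a minimal sketch that runs a single step of `DPMSolverSDEScheduler` on toy tensors and reads both fields of the returned `DPMSolverSDESchedulerOutput`. It assumes `diffusers` and `torchsde` are installed (the SDE scheduler imports `torchsde`); the shapes, step count, and the random stand-in for a UNet output are illustrative only.

import torch
from diffusers import DPMSolverSDEScheduler

# Toy setup: no real UNet is involved, shapes and step count are arbitrary.
scheduler = DPMSolverSDEScheduler()
scheduler.set_timesteps(num_inference_steps=10)

sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
timestep = scheduler.timesteps[0]

# A real pipeline would get model_output from a UNet; random noise stands in here.
model_output = torch.randn_like(sample)

output = scheduler.step(model_output, timestep, sample)
print(output.prev_sample.shape)           # next latents, fed back into the denoising loop
print(output.pred_original_sample.shape)  # predicted x_0, usable for previews or guidance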
src/diffusers/schedulers/scheduling_dpmsolver_sde.py

@@ -13,6 +13,7 @@
 # limitations under the License.

 import math
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

 import numpy as np
@@ -20,14 +21,33 @@ import torch
 import torchsde

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import is_scipy_available
-from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+from ..utils import BaseOutput, is_scipy_available
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


 if is_scipy_available():
     import scipy.stats


+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DPMSolverSDE
+class DPMSolverSDESchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+
+    Args:
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
+
+
 class BatchedBrownianTree:
     """A wrapper around torchsde.BrownianTree that enables batches of entropy."""

@@ -510,7 +530,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         sample: Union[torch.Tensor, np.ndarray],
         return_dict: bool = True,
         s_noise: float = 1.0,
-    ) -> Union[SchedulerOutput, Tuple]:
+    ) -> Union[DPMSolverSDESchedulerOutput, Tuple]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -522,15 +542,16 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor` or `np.ndarray`):
                 A current instance of a sample created by the diffusion process.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
+            return_dict (`bool`):
+                Whether or not to return a [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or
+                tuple.
             s_noise (`float`, *optional*, defaults to 1.0):
                 Scaling factor for noise added to the sample.

         Returns:
-            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
-                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
-                tuple is returned where the first element is the sample tensor.
+            [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] or `tuple`:
+                If return_dict is `True`, [`~schedulers.scheduling_dpmsolver_sde.DPMSolverSDESchedulerOutput`] is
+                returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
         if self.step_index is None:
             self._init_step_index(timestep)
@@ -610,9 +631,12 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         self._step_index += 1

         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )

-        return SchedulerOutput(prev_sample=prev_sample)
+        return DPMSolverSDESchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
...
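
Note that with `return_dict=False` the tuple returned by `step` now has two elements instead of one, so positional indexing (rather than single-element unpacking) keeps existing call sites working across versions. A minimal sketch, reusing the toy `scheduler`, `model_output`, `timestep`, and `sample` from the sketch above:

# The tuple now carries (prev_sample, pred_original_sample).
step_tuple = scheduler.step(model_output, timestep, sample, return_dict=False)
prev_sample = step_tuple[0]           # works before and after this change
pred_original_sample = step_tuple[1]  # second element added by this commit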
src/diffusers/schedulers/scheduling_heun_discrete.py

@@ -13,20 +13,40 @@
 # limitations under the License.

 import math
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

 import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import is_scipy_available
-from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+from ..utils import BaseOutput, is_scipy_available
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


 if is_scipy_available():
     import scipy.stats


+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->HeunDiscrete
+class HeunDiscreteSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+
+    Args:
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
+
+
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
@@ -455,7 +475,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         timestep: Union[float, torch.Tensor],
         sample: Union[torch.Tensor, np.ndarray],
         return_dict: bool = True,
-    ) -> Union[SchedulerOutput, Tuple]:
+    ) -> Union[HeunDiscreteSchedulerOutput, Tuple]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -468,12 +488,13 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             return_dict (`bool`):
-                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
+                Whether or not to return a [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] or
+                tuple.

         Returns:
-            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
-                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
-                tuple is returned where the first element is the sample tensor.
+            [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`:
+                If return_dict is `True`, [`~schedulers.scheduling_heun_discrete.HeunDiscreteSchedulerOutput`] is
+                returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
         if self.step_index is None:
             self._init_step_index(timestep)
@@ -544,9 +565,12 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self._step_index += 1

         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )

-        return SchedulerOutput(prev_sample=prev_sample)
+        return HeunDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
...
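
Because `HeunDiscreteScheduler` now surfaces the denoised prediction as well, a denoising loop can inspect it at every step, for example to render intermediate previews. The loop below is a toy sketch with plain noise standing in for a UNet call; shapes, step count, and the preview variable are illustrative only.

import torch
from diffusers import HeunDiscreteScheduler

scheduler = HeunDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=15)

sample = torch.randn(1, 4, 32, 32) * scheduler.init_noise_sigma

for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)
    # Stand-in for a UNet call such as unet(model_input, t).sample
    model_output = torch.randn_like(model_input)

    out = scheduler.step(model_output, t, sample)
    sample = out.prev_sample

    # out.pred_original_sample is the scheduler's current estimate of x_0;
    # a real pipeline could decode it with the VAE to preview progress.
    preview = out.pred_original_sample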
src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py

@@ -13,21 +13,41 @@
 # limitations under the License.

 import math
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

 import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import is_scipy_available
+from ..utils import BaseOutput, is_scipy_available
 from ..utils.torch_utils import randn_tensor
-from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


 if is_scipy_available():
     import scipy.stats


+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->KDPM2AncestralDiscrete
+class KDPM2AncestralDiscreteSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+
+    Args:
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
+
+
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
@@ -459,7 +479,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sample: Union[torch.Tensor, np.ndarray],
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
-    ) -> Union[SchedulerOutput, Tuple]:
+    ) -> Union[KDPM2AncestralDiscreteSchedulerOutput, Tuple]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -474,12 +494,14 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
             generator (`torch.Generator`, *optional*):
                 A random number generator.
             return_dict (`bool`):
-                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
+                Whether or not to return a
+                [`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] or tuple.

         Returns:
-            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
-                If return_dict is `True`, [`~schedulers.scheduling_ddim.SchedulerOutput`] is returned, otherwise a
-                tuple is returned where the first element is the sample tensor.
+            [`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] or `tuple`:
+                If return_dict is `True`,
+                [`~schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteSchedulerOutput`] is
+                returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
         if self.step_index is None:
             self._init_step_index(timestep)
@@ -548,9 +570,14 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self._step_index += 1

         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )

-        return SchedulerOutput(prev_sample=prev_sample)
+        return KDPM2AncestralDiscreteSchedulerOutput(
+            prev_sample=prev_sample, pred_original_sample=pred_original_sample
+        )

     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
...
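
All four new output classes derive from `BaseOutput`, so, like other diffusers outputs, they can be read by attribute, by integer index, or converted with `to_tuple()`. A small sketch constructing `KDPM2AncestralDiscreteSchedulerOutput` directly from dummy tensors:

import torch
from diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete import (
    KDPM2AncestralDiscreteSchedulerOutput,
)

out = KDPM2AncestralDiscreteSchedulerOutput(
    prev_sample=torch.zeros(1, 4, 8, 8),
    pred_original_sample=torch.ones(1, 4, 8, 8),
)

assert out.prev_sample is out[0]            # attribute and index access agree
assert out.pred_original_sample is out[1]
prev_sample, pred_original_sample = out.to_tuple()  # same order as the return_dict=False tuple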
src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py

@@ -13,20 +13,40 @@
 # limitations under the License.

 import math
+from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

 import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import is_scipy_available
-from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+from ..utils import BaseOutput, is_scipy_available
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


 if is_scipy_available():
     import scipy.stats


+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->KDPM2Discrete
+class KDPM2DiscreteSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's `step` function output.
+
+    Args:
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
+
+
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
@@ -443,7 +463,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         timestep: Union[float, torch.Tensor],
         sample: Union[torch.Tensor, np.ndarray],
         return_dict: bool = True,
-    ) -> Union[SchedulerOutput, Tuple]:
+    ) -> Union[KDPM2DiscreteSchedulerOutput, Tuple]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
         process from the learned model outputs (most often the predicted noise).
@@ -456,12 +476,13 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             return_dict (`bool`):
-                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
+                Whether or not to return a [`~schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteSchedulerOutput`] or
+                tuple.

         Returns:
-            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
-                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
-                tuple is returned where the first element is the sample tensor.
+            [`~schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteSchedulerOutput`] or `tuple`:
+                If return_dict is `True`, [`~schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteSchedulerOutput`] is
+                returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
         if self.step_index is None:
             self._init_step_index(timestep)
@@ -523,9 +544,12 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         prev_sample = sample + derivative * dt

         if not return_dict:
-            return (prev_sample,)
+            return (
+                prev_sample,
+                pred_original_sample,
+            )

-        return SchedulerOutput(prev_sample=prev_sample)
+        return KDPM2DiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
...
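
Finally, an informal way to sanity-check the `KDPM2DiscreteScheduler` change locally is a short loop asserting that the prediction is exposed with the expected shape. This is only a sketch under arbitrary shapes and step count, not part of the repository's test suite.

import torch
from diffusers import KDPM2DiscreteScheduler

scheduler = KDPM2DiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=8)

sample = torch.randn(1, 4, 16, 16) * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    model_output = torch.randn_like(sample)  # stand-in for a UNet prediction
    out = scheduler.step(model_output, t, sample)
    assert out.pred_original_sample.shape == sample.shape
    assert len(out.to_tuple()) == 2          # (prev_sample, pred_original_sample)
    sample = out.prev_sample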