Unverified commit a3ae4661, authored by Will Berman and committed by GitHub
Browse files

schedulers add glide noising schedule (#2347)

parent c613288c
...@@ -46,6 +46,7 @@ class DDIMSchedulerOutput(BaseOutput): ...@@ -46,6 +46,7 @@ class DDIMSchedulerOutput(BaseOutput):
pred_original_sample: Optional[torch.FloatTensor] = None pred_original_sample: Optional[torch.FloatTensor] = None
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
""" """
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
...@@ -72,7 +73,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor ...@@ -72,7 +73,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
t1 = i / num_diffusion_timesteps t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return torch.tensor(betas) return torch.tensor(betas, dtype=torch.float32)
class DDIMScheduler(SchedulerMixin, ConfigMixin): class DDIMScheduler(SchedulerMixin, ConfigMixin):
......
...@@ -25,6 +25,7 @@ from ..configuration_utils import ConfigMixin, register_to_config ...@@ -25,6 +25,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
""" """
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
......
...@@ -24,6 +24,7 @@ from ..configuration_utils import ConfigMixin, register_to_config ...@@ -24,6 +24,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
""" """
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
......
...@@ -24,6 +24,7 @@ from ..configuration_utils import ConfigMixin, register_to_config ...@@ -24,6 +24,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
""" """
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
...@@ -45,6 +46,36 @@ class EulerAncestralDiscreteSchedulerOutput(BaseOutput): ...@@ -45,6 +46,36 @@ class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
pred_original_sample: Optional[torch.FloatTensor] = None pred_original_sample: Optional[torch.FloatTensor] = None
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.

    Returns:
        betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
    """

    def alpha_bar(time_step):
        # Cosine alpha_bar from Nichol & Dhariwal (used by GLIDE); the 0.008 offset
        # keeps beta_t from being vanishingly small near t = 0.
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_t = 1 - alpha_bar(t2) / alpha_bar(t1); clip with max_beta since
        # alpha_bar(1.0) == 0 would otherwise give beta = 1 (a singularity).
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
""" """
Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson: Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson:
...@@ -93,6 +124,9 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -93,6 +124,9 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.betas = ( self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
) )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else: else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
...@@ -214,6 +248,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -214,6 +248,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
elif self.config.prediction_type == "v_prediction": elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip # * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
elif self.config.prediction_type == "sample":
raise NotImplementedError("prediction_type not implemented yet: sample")
else: else:
raise ValueError( raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
...@@ -45,6 +46,36 @@ class EulerDiscreteSchedulerOutput(BaseOutput): ...@@ -45,6 +46,36 @@ class EulerDiscreteSchedulerOutput(BaseOutput):
pred_original_sample: Optional[torch.FloatTensor] = None pred_original_sample: Optional[torch.FloatTensor] = None
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.

    Returns:
        betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
    """

    def alpha_bar(time_step):
        # Cosine alpha_bar from Nichol & Dhariwal (used by GLIDE); the 0.008 offset
        # keeps beta_t from being vanishingly small near t = 0.
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_t = 1 - alpha_bar(t2) / alpha_bar(t1); clip with max_beta since
        # alpha_bar(1.0) == 0 would otherwise give beta = 1 (a singularity).
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
""" """
Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. . Based on the original Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. . Based on the original
...@@ -97,6 +128,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -97,6 +128,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.betas = ( self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
) )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else: else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
...@@ -245,7 +279,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -245,7 +279,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "original_sample": # NOTE: "original_sample" should not be an expected prediction_type but is left in for
# backwards compatibility
if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample":
pred_original_sample = model_output pred_original_sample = model_output
elif self.config.prediction_type == "epsilon": elif self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma_hat * model_output pred_original_sample = sample - sigma_hat * model_output
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import math
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -21,6 +22,36 @@ from ..configuration_utils import ConfigMixin, register_to_config ...@@ -21,6 +22,36 @@ from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.

    Returns:
        betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
    """

    def alpha_bar(time_step):
        # Cosine alpha_bar from Nichol & Dhariwal (used by GLIDE); the 0.008 offset
        # keeps beta_t from being vanishingly small near t = 0.
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_t = 1 - alpha_bar(t2) / alpha_bar(t1); clip with max_beta since
        # alpha_bar(1.0) == 0 would otherwise give beta = 1 (a singularity).
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
""" """
Implements Algorithm 2 (Heun steps) from Karras et al. (2022). for discrete beta schedules. Based on the original Implements Algorithm 2 (Heun steps) from Karras et al. (2022). for discrete beta schedules. Based on the original
...@@ -69,6 +100,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -69,6 +100,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.betas = ( self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
) )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else: else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
...@@ -197,6 +231,8 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -197,6 +231,8 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + ( pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
sample / (sigma_input**2 + 1) sample / (sigma_input**2 + 1)
) )
elif self.config.prediction_type == "sample":
raise NotImplementedError("prediction_type not implemented yet: sample")
else: else:
raise ValueError( raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import math
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -22,6 +23,36 @@ from ..utils import randn_tensor ...@@ -22,6 +23,36 @@ from ..utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.

    Returns:
        betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
    """

    def alpha_bar(time_step):
        # Cosine alpha_bar from Nichol & Dhariwal (used by GLIDE); the 0.008 offset
        # keeps beta_t from being vanishingly small near t = 0.
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_t = 1 - alpha_bar(t2) / alpha_bar(t1); clip with max_beta since
        # alpha_bar(1.0) == 0 would otherwise give beta = 1 (a singularity).
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
""" """
Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see: Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see:
...@@ -71,6 +102,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -71,6 +102,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.betas = ( self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
) )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else: else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
...@@ -254,6 +288,8 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -254,6 +288,8 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + ( pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
sample / (sigma_input**2 + 1) sample / (sigma_input**2 + 1)
) )
elif self.config.prediction_type == "sample":
raise NotImplementedError("prediction_type not implemented yet: sample")
else: else:
raise ValueError( raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import math
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -21,6 +22,36 @@ from ..configuration_utils import ConfigMixin, register_to_config ...@@ -21,6 +22,36 @@ from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.

    Returns:
        betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
    """

    def alpha_bar(time_step):
        # Cosine alpha_bar from Nichol & Dhariwal (used by GLIDE); the 0.008 offset
        # keeps beta_t from being vanishingly small near t = 0.
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_t = 1 - alpha_bar(t2) / alpha_bar(t1); clip with max_beta since
        # alpha_bar(1.0) == 0 would otherwise give beta = 1 (a singularity).
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
""" """
Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see: Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see:
...@@ -70,6 +101,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -70,6 +101,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
self.betas = ( self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
) )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else: else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
...@@ -237,6 +271,8 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -237,6 +271,8 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + ( pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
sample / (sigma_input**2 + 1) sample / (sigma_input**2 + 1)
) )
elif self.config.prediction_type == "sample":
raise NotImplementedError("prediction_type not implemented yet: sample")
else: else:
raise ValueError( raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import math
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
...@@ -43,6 +44,36 @@ class LMSDiscreteSchedulerOutput(BaseOutput): ...@@ -43,6 +44,36 @@ class LMSDiscreteSchedulerOutput(BaseOutput):
pred_original_sample: Optional[torch.FloatTensor] = None pred_original_sample: Optional[torch.FloatTensor] = None
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.

    Returns:
        betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
    """

    def alpha_bar(time_step):
        # Cosine alpha_bar from Nichol & Dhariwal (used by GLIDE); the 0.008 offset
        # keeps beta_t from being vanishingly small near t = 0.
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_t = 1 - alpha_bar(t2) / alpha_bar(t1); clip with max_beta since
        # alpha_bar(1.0) == 0 would otherwise give beta = 1 (a singularity).
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
""" """
Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
...@@ -91,6 +122,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -91,6 +122,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.betas = ( self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
) )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else: else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
...@@ -223,6 +257,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -223,6 +257,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
elif self.config.prediction_type == "v_prediction": elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip # * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
else: else:
raise ValueError( raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
......
...@@ -24,6 +24,7 @@ from ..configuration_utils import ConfigMixin, register_to_config ...@@ -24,6 +24,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
""" """
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
......
...@@ -42,6 +42,7 @@ class RePaintSchedulerOutput(BaseOutput): ...@@ -42,6 +42,7 @@ class RePaintSchedulerOutput(BaseOutput):
pred_original_sample: torch.FloatTensor pred_original_sample: torch.FloatTensor
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
""" """
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
......
...@@ -43,6 +43,7 @@ class UnCLIPSchedulerOutput(BaseOutput): ...@@ -43,6 +43,7 @@ class UnCLIPSchedulerOutput(BaseOutput):
pred_original_sample: Optional[torch.FloatTensor] = None pred_original_sample: Optional[torch.FloatTensor] = None
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
""" """
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
......
...@@ -89,7 +89,7 @@ class StableDiffusionLatentUpscalePipelineFastTests(PipelineTesterMixin, unittes ...@@ -89,7 +89,7 @@ class StableDiffusionLatentUpscalePipelineFastTests(PipelineTesterMixin, unittes
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4, latent_channels=4,
) )
scheduler = EulerDiscreteScheduler(prediction_type="original_sample") scheduler = EulerDiscreteScheduler(prediction_type="sample")
text_config = CLIPTextConfig( text_config = CLIPTextConfig(
bos_token_id=0, bos_token_id=0,
eos_token_id=2, eos_token_id=2,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment