Unverified Commit 3c1cdd33 authored by Santiago Víquez's avatar Santiago Víquez Committed by GitHub
Browse files

[Type hint] scheduling karras ve (#359)

parent 07f8ebd5
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Tuple, Union from typing import Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -54,13 +54,13 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -54,13 +54,13 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
sigma_min=0.02, sigma_min: float = 0.02,
sigma_max=100, sigma_max: float = 100,
s_noise=1.007, s_noise: float = 1.007,
s_churn=80, s_churn: float = 80,
s_min=0.05, s_min: float = 0.05,
s_max=50, s_max: float = 50,
tensor_format="pt", tensor_format: str = "pt",
): ):
""" """
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
...@@ -87,7 +87,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -87,7 +87,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
self.tensor_format = tensor_format self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format) self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps): def set_timesteps(self, num_inference_steps: int):
self.num_inference_steps = num_inference_steps self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
self.schedule = [ self.schedule = [
...@@ -98,7 +98,9 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): ...@@ -98,7 +98,9 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
self.set_format(tensor_format=self.tensor_format) self.set_format(tensor_format=self.tensor_format)
def add_noise_to_input(self, sample, sigma, generator=None): def add_noise_to_input(
self, sample: Union[torch.FloatTensor, np.ndarray], sigma: float, generator: Optional[torch.Generator] = None
) -> Tuple[Union[torch.FloatTensor, np.ndarray], float]:
""" """
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment