Unverified Commit be52be72 authored by Santiago Víquez's avatar Santiago Víquez Committed by GitHub
Browse files

[Type hint] scheduling lms discrete (#360)

* [Type hint] scheduling karras ve

* [Type hint] scheduling lms discrete
parent 3c1cdd33
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Tuple, Union from typing import Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -27,13 +27,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -27,13 +27,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
num_train_timesteps=1000, num_train_timesteps: int = 1000,
beta_start=0.0001, beta_start: float = 0.0001,
beta_end=0.02, beta_end: float = 0.02,
beta_schedule="linear", beta_schedule: str = "linear",
trained_betas=None, trained_betas: Optional[np.ndarray] = None,
timestep_values=None, timestep_values: Optional[np.ndarray] = None,
tensor_format="pt", tensor_format: str = "pt",
): ):
""" """
Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
...@@ -79,7 +79,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -79,7 +79,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
return integrated_coeff return integrated_coeff
def set_timesteps(self, num_inference_steps): def set_timesteps(self, num_inference_steps: int):
self.num_inference_steps = num_inference_steps self.num_inference_steps = num_inference_steps
self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
...@@ -127,7 +127,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -127,7 +127,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample) return SchedulerOutput(prev_sample=prev_sample)
def add_noise(self, original_samples, noise, timesteps): def add_noise(
self,
original_samples: Union[torch.FloatTensor, np.ndarray],
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> Union[torch.FloatTensor, np.ndarray]:
sigmas = self.match_shape(self.sigmas[timesteps], noise) sigmas = self.match_shape(self.sigmas[timesteps], noise)
noisy_samples = original_samples + noise * sigmas noisy_samples = original_samples + noise * sigmas
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment