"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3ad49eeeddc5b3a82540bd37ac133650d02ad93d"
Unverified Commit 878af0e1 authored by Partho's avatar Partho Committed by GitHub
Browse files

[Type Hint] DDPM schedulers (#349)

parent dea5ec50
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math import math
from typing import Union from typing import Optional, Union
import numpy as np import numpy as np
import torch import torch
...@@ -51,14 +51,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -51,14 +51,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
num_train_timesteps=1000, num_train_timesteps: int = 1000,
beta_start=0.0001, beta_start: float = 0.0001,
beta_end=0.02, beta_end: float = 0.02,
beta_schedule="linear", beta_schedule: str = "linear",
trained_betas=None, trained_betas: Optional[np.ndarray] = None,
variance_type="fixed_small", variance_type: str = "fixed_small",
clip_sample=True, clip_sample: bool = True,
tensor_format="pt", tensor_format: str = "pt",
): ):
if trained_betas is not None: if trained_betas is not None:
...@@ -87,7 +87,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -87,7 +87,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
self.variance_type = variance_type self.variance_type = variance_type
def set_timesteps(self, num_inference_steps): def set_timesteps(self, num_inference_steps: int):
num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps) num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
self.num_inference_steps = num_inference_steps self.num_inference_steps = num_inference_steps
self.timesteps = np.arange( self.timesteps = np.arange(
...@@ -179,7 +179,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -179,7 +179,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return {"prev_sample": pred_prev_sample} return {"prev_sample": pred_prev_sample}
def add_noise(self, original_samples, noise, timesteps): def add_noise(
self,
original_samples: Union[torch.FloatTensor, np.ndarray],
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> Union[torch.FloatTensor, np.ndarray]:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment