Unverified Commit dea5ec50 authored by Partho's avatar Partho Committed by GitHub
Browse files

[Type hint] PNDM schedulers (#335)

* [Type hint] PNDM Schedulers

* ran make style

* updated timesteps type hint

* apply suggestions from code review

* ran make style

* removed unused import
parent 6c0ca5ef
...@@ -51,12 +51,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -51,12 +51,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
num_train_timesteps=1000, num_train_timesteps: int = 1000,
beta_start=0.0001, beta_start: float = 0.0001,
beta_end=0.02, beta_end: float = 0.02,
beta_schedule="linear", beta_schedule: str = "linear",
tensor_format="pt", tensor_format: str = "pt",
skip_prk_steps=False, skip_prk_steps: bool = False,
): ):
if beta_schedule == "linear": if beta_schedule == "linear":
...@@ -97,7 +97,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -97,7 +97,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.tensor_format = tensor_format self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format) self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps, offset=0): def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor:
self.num_inference_steps = num_inference_steps self.num_inference_steps = num_inference_steps
self._timesteps = list( self._timesteps = list(
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps) range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
...@@ -264,7 +264,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -264,7 +264,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return prev_sample return prev_sample
def add_noise(self, original_samples, noise, timesteps): def add_noise(
self,
original_samples: Union[torch.FloatTensor, np.ndarray],
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> torch.Tensor:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment