Unverified Commit 33045382 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[DDIM, DDPM] fix add_noise (#648)

fix add noise
parent e5eed523
......@@ -282,7 +282,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
timesteps = timesteps.to(self.alphas_cumprod.device)
if self.alphas_cumprod.device != original_samples.device:
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
if timesteps.device != original_samples.device:
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
......
......@@ -268,7 +268,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
timesteps = timesteps.to(self.alphas_cumprod.device)
if self.alphas_cumprod.device != original_samples.device:
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
if timesteps.device != original_samples.device:
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
......@@ -276,7 +280,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod.flatten()
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
......
......@@ -387,8 +387,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
if timesteps.device != original_samples.device:
timesteps = timesteps.to(original_samples.device)
timesteps = timesteps.to(self.alphas_cumprod.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment