"vscode:/vscode.git/clone" did not exist on "f593bfd3c258b0ff2b7bdbabfb06ab5210b43a52"
Unverified Commit 33045382 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[DDIM, DDPM] fix add_noise (#648)

fix add noise
parent e5eed523
...@@ -282,7 +282,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -282,7 +282,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
noise: torch.FloatTensor, noise: torch.FloatTensor,
timesteps: torch.IntTensor, timesteps: torch.IntTensor,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
timesteps = timesteps.to(self.alphas_cumprod.device) if self.alphas_cumprod.device != original_samples.device:
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
if timesteps.device != original_samples.device:
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten() sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape): while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
......
...@@ -268,7 +268,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -268,7 +268,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
noise: torch.FloatTensor, noise: torch.FloatTensor,
timesteps: torch.IntTensor, timesteps: torch.IntTensor,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
timesteps = timesteps.to(self.alphas_cumprod.device) if self.alphas_cumprod.device != original_samples.device:
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
if timesteps.device != original_samples.device:
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten() sqrt_alpha_prod = sqrt_alpha_prod.flatten()
...@@ -276,7 +280,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -276,7 +280,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod.flatten() sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
......
...@@ -387,8 +387,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -387,8 +387,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
if timesteps.device != original_samples.device: if timesteps.device != original_samples.device:
timesteps = timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device)
timesteps = timesteps.to(self.alphas_cumprod.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten() sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape): while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment