Unverified commit 5b11c5dc, authored by YiYi Xu, committed by GitHub
Browse files

fix the add_noise function for dpm-multi et al (#5158)



* remove the .to(device) call for sigmas

* update add_noise to use sigmas

---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
parent 310cf328
...@@ -243,8 +243,8 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -243,8 +243,8 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device) self.sigmas = torch.from_numpy(sigmas)
self.timesteps = torch.from_numpy(timesteps).to(device) self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
self.num_inference_steps = len(timesteps) self.num_inference_steps = len(timesteps)
...@@ -707,12 +707,12 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -707,12 +707,12 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
""" """
return sample return sample
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.FloatTensor,
noise: torch.FloatTensor, noise: torch.FloatTensor,
timesteps: torch.FloatTensor, timesteps: torch.IntTensor,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
...@@ -730,7 +730,8 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -730,7 +730,8 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
while len(sigma.shape) < len(original_samples.shape): while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1) sigma = sigma.unsqueeze(-1)
noisy_samples = original_samples + noise * sigma alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples return noisy_samples
def __len__(self): def __len__(self):
......
...@@ -263,8 +263,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -263,8 +263,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device) self.sigmas = torch.from_numpy(sigmas)
self.timesteps = torch.from_numpy(timesteps).to(device) self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
self.num_inference_steps = len(timesteps) self.num_inference_steps = len(timesteps)
...@@ -840,12 +840,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -840,12 +840,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
""" """
return sample return sample
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.FloatTensor,
noise: torch.FloatTensor, noise: torch.FloatTensor,
timesteps: torch.FloatTensor, timesteps: torch.IntTensor,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
...@@ -863,7 +862,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -863,7 +862,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
while len(sigma.shape) < len(original_samples.shape): while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1) sigma = sigma.unsqueeze(-1)
noisy_samples = original_samples + noise * sigma alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples return noisy_samples
def __len__(self): def __len__(self):
......
...@@ -274,7 +274,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -274,7 +274,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
_, unique_indices = np.unique(timesteps, return_index=True) _, unique_indices = np.unique(timesteps, return_index=True)
timesteps = timesteps[np.sort(unique_indices)] timesteps = timesteps[np.sort(unique_indices)]
self.timesteps = torch.from_numpy(timesteps).to(device) self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
self.num_inference_steps = len(timesteps) self.num_inference_steps = len(timesteps)
...@@ -858,12 +858,12 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -858,12 +858,12 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
""" """
return sample return sample
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.FloatTensor,
noise: torch.FloatTensor, noise: torch.FloatTensor,
timesteps: torch.FloatTensor, timesteps: torch.IntTensor,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
...@@ -881,7 +881,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -881,7 +881,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
while len(sigma.shape) < len(original_samples.shape): while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1) sigma = sigma.unsqueeze(-1)
noisy_samples = original_samples + noise * sigma alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples return noisy_samples
def __len__(self): def __len__(self):
......
...@@ -275,7 +275,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -275,7 +275,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = torch.from_numpy(sigmas).to(device=device) self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device) self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
self.model_outputs = [None] * self.config.solver_order self.model_outputs = [None] * self.config.solver_order
self.sample = None self.sample = None
...@@ -870,12 +870,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -870,12 +870,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
""" """
return sample return sample
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.FloatTensor,
noise: torch.FloatTensor, noise: torch.FloatTensor,
timesteps: torch.FloatTensor, timesteps: torch.IntTensor,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
...@@ -893,7 +893,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -893,7 +893,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
while len(sigma.shape) < len(original_samples.shape): while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1) sigma = sigma.unsqueeze(-1)
noisy_samples = original_samples + noise * sigma alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples return noisy_samples
def __len__(self): def __len__(self):
......
...@@ -254,8 +254,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -254,8 +254,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas).to(device=device) self.sigmas = torch.from_numpy(sigmas)
self.timesteps = torch.from_numpy(timesteps).to(device) self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
self.num_inference_steps = len(timesteps) self.num_inference_steps = len(timesteps)
...@@ -801,12 +801,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -801,12 +801,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
""" """
return sample return sample
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.FloatTensor,
noise: torch.FloatTensor, noise: torch.FloatTensor,
timesteps: torch.FloatTensor, timesteps: torch.IntTensor,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples # Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
...@@ -824,7 +824,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -824,7 +824,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
while len(sigma.shape) < len(original_samples.shape): while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1) sigma = sigma.unsqueeze(-1)
noisy_samples = original_samples + noise * sigma alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples return noisy_samples
def __len__(self): def __len__(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment