"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "df476d9f63891406db2d531aa5faf195193e3354"
Unverified Commit 3b9b9865 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

Fix a bug in `add_noise` function (#6085)



* fix

* copies

---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
parent b65928b5
...@@ -734,7 +734,16 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -734,7 +734,16 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device) schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] step_indices = []
for timestep in timesteps:
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(schedule_timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)
sigma = sigmas[step_indices].flatten() sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape): while len(sigma.shape) < len(original_samples.shape):
......
...@@ -896,7 +896,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -896,7 +896,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device) schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] step_indices = []
for timestep in timesteps:
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(schedule_timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)
sigma = sigmas[step_indices].flatten() sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape): while len(sigma.shape) < len(original_samples.shape):
......
...@@ -891,7 +891,16 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -891,7 +891,16 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device) schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] step_indices = []
for timestep in timesteps:
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(schedule_timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)
sigma = sigmas[step_indices].flatten() sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape): while len(sigma.shape) < len(original_samples.shape):
......
...@@ -897,7 +897,16 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -897,7 +897,16 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device) schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] step_indices = []
for timestep in timesteps:
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(schedule_timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)
sigma = sigmas[step_indices].flatten() sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape): while len(sigma.shape) < len(original_samples.shape):
......
...@@ -828,7 +828,16 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -828,7 +828,16 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device) schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] step_indices = []
for timestep in timesteps:
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(schedule_timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)
sigma = sigmas[step_indices].flatten() sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape): while len(sigma.shape) < len(original_samples.shape):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment