Commit 953c9d14 authored by William Berman, committed by Will Berman

[bug fix] dpm multistep solver duplicate timesteps

parent 85f1c192
@@ -192,14 +192,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             device (`str` or `torch.device`, optional):
                 the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
-        self.num_inference_steps = num_inference_steps
         timesteps = (
             np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
             .round()[::-1][:-1]
             .copy()
             .astype(np.int64)
         )
+        # when num_inference_steps == num_train_timesteps, we can end up with
+        # duplicates in timesteps.
+        _, unique_indices = np.unique(timesteps, return_index=True)
+        timesteps = timesteps[np.sort(unique_indices)]
         self.timesteps = torch.from_numpy(timesteps).to(device)
+        self.num_inference_steps = len(timesteps)
         self.model_outputs = [
             None,
         ] * self.config.solver_order
...
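For context, a minimal standalone sketch of the behavior this hunk changes (plain NumPy, not the scheduler source; the step counts below are illustrative assumptions). It shows how rounding the linspace grid can produce a repeated timestep when `num_inference_steps == num_train_timesteps`, and how the deduplication added above removes it while preserving the descending order:

```python
import numpy as np

# Illustrative values, assumed for this sketch.
num_train_timesteps = 1000
num_inference_steps = 1000

# Same construction as in set_timesteps: an evenly spaced grid, rounded to
# integers, reversed, with the trailing 0 dropped.
timesteps = (
    np.linspace(0, num_train_timesteps - 1, num_inference_steps + 1)
    .round()[::-1][:-1]
    .copy()
    .astype(np.int64)
)
print(len(timesteps), len(np.unique(timesteps)))  # e.g. 1000 999: one value repeats

# The fix: keep only the first occurrence of each value, preserve the original
# (descending) order, and recompute the step count from the actual array length.
_, unique_indices = np.unique(timesteps, return_index=True)
timesteps = timesteps[np.sort(unique_indices)]
num_inference_steps = len(timesteps)
print(num_inference_steps)  # e.g. 999 after deduplication
```

Before this change, `self.num_inference_steps` was taken from the caller's argument even though `self.timesteps` could contain a duplicate, so the multistep solver would see the same timestep twice in a row; recomputing the count from the deduplicated array keeps the two consistent.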