Commit 074d281a authored by William Berman, committed by Will Berman

tests and additional scheduler fixes

parent 953c9d14
@@ -171,6 +171,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         self.model_outputs = [None] * solver_order
         self.lower_order_nums = 0
 
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_timesteps
     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -181,14 +182,22 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
             device (`str` or `torch.device`, optional):
                 the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
-        self.num_inference_steps = num_inference_steps
         timesteps = (
             np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
             .round()[::-1][:-1]
             .copy()
             .astype(np.int64)
         )
+
+        # when num_inference_steps == num_train_timesteps, we can end up with
+        # duplicates in timesteps.
+        _, unique_indices = np.unique(timesteps, return_index=True)
+        timesteps = timesteps[np.sort(unique_indices)]
+
         self.timesteps = torch.from_numpy(timesteps).to(device)
+
+        self.num_inference_steps = len(timesteps)
+
         self.model_outputs = [
             None,
         ] * self.config.solver_order
...
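For context: when `num_inference_steps` equals `num_train_timesteps`, rounding the linspace grid yields repeated integers, so the scheduler would otherwise step twice on the same timestep. Below is a minimal standalone reproduction of the rounding collision and of the order-preserving dedup (plain NumPy, not part of the commit; a 10-step schedule is used for brevity):

```python
import numpy as np

num_train_timesteps = 10
num_inference_steps = 10

# Same construction as in set_timesteps above: an 11-point grid over
# [0, 9], rounded, reversed, with the final (t = 0) entry dropped.
timesteps = (
    np.linspace(0, num_train_timesteps - 1, num_inference_steps + 1)
    .round()[::-1][:-1]
    .copy()
    .astype(np.int64)
)
print(timesteps)  # [9 8 7 6 5 4 4 3 2 1] -- note the duplicated 4

# np.unique returns the values sorted ascending plus the index of each
# value's first occurrence; sorting those indices restores the original
# descending order while dropping the duplicates.
_, unique_indices = np.unique(timesteps, return_index=True)
timesteps = timesteps[np.sort(unique_indices)]
print(timesteps)       # [9 8 7 6 5 4 3 2 1]
print(len(timesteps))  # 9
```

Because a duplicate was dropped, the final grid can be shorter than requested, which is why both hunks now recompute `self.num_inference_steps = len(timesteps)` after the dedup.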
@@ -194,21 +194,29 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
             device (`str` or `torch.device`, optional):
                 the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
-        self.num_inference_steps = num_inference_steps
         timesteps = (
             np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
             .round()[::-1][:-1]
             .copy()
             .astype(np.int64)
         )
+
+        # when num_inference_steps == num_train_timesteps, we can end up with
+        # duplicates in timesteps.
+        _, unique_indices = np.unique(timesteps, return_index=True)
+        timesteps = timesteps[np.sort(unique_indices)]
+
         self.timesteps = torch.from_numpy(timesteps).to(device)
+
+        self.num_inference_steps = len(timesteps)
+
         self.model_outputs = [
             None,
         ] * self.config.solver_order
         self.lower_order_nums = 0
         self.last_sample = None
         if self.solver_p:
-            self.solver_p.set_timesteps(num_inference_steps, device=device)
+            self.solver_p.set_timesteps(self.num_inference_steps, device=device)
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
     def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
...
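The UniPC hunk makes one extra change: the nested predictor `solver_p` is now seeded with the recomputed `self.num_inference_steps`, keeping its schedule the same length as the deduplicated grid. A quick check of the recomputed count (a sketch assuming the `diffusers` package at this commit; later releases build the timestep grid differently, so the exact values may vary):

```python
from diffusers import UniPCMultistepScheduler

# Tiny training schedule so that rounding produces a duplicate timestep.
scheduler = UniPCMultistepScheduler(num_train_timesteps=10)
scheduler.set_timesteps(num_inference_steps=10)

print(scheduler.timesteps)            # tensor([9, 8, 7, 6, 5, 4, 3, 2, 1])
print(scheduler.num_inference_steps)  # 9, not the requested 10
```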
@@ -243,3 +243,11 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
         sample = scheduler.step(residual, t, sample).prev_sample
 
         assert sample.dtype == torch.float16
+
+    def test_unique_timesteps(self, **config):
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+
+            scheduler.set_timesteps(scheduler.config.num_train_timesteps)
+            assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps
@@ -229,3 +229,11 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
         sample = scheduler.step(residual, t, sample).prev_sample
 
         assert sample.dtype == torch.float16
+
+    def test_unique_timesteps(self, **config):
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+
+            scheduler.set_timesteps(scheduler.config.num_train_timesteps)
+            assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps
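The same `test_unique_timesteps` is added to each affected test class. Outside the test harness, the check looks like this (a minimal sketch assuming `diffusers` is installed; `DPMSolverMultistepScheduler` matches the test class shown above):

```python
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler()  # default: 1000 train timesteps
scheduler.set_timesteps(scheduler.config.num_train_timesteps)

# After the fix, every timestep is unique and num_inference_steps reflects
# the deduplicated length rather than the requested step count.
assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps
assert len(scheduler.timesteps) == scheduler.num_inference_steps
```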