Unverified Commit 5effcd3e authored by Anand Kumar, committed by GitHub

[Bug fix] "previous_timestep()" in DDPM scheduling compatible with "trailing"...


[Bug fix] "previous_timestep()" in DDPM scheduling compatible with "trailing" and "linspace" options (#9384)

* Update scheduling_ddpm.py

* fix copies

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
parent 619b9658
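
For context (not part of the commit message), a minimal sketch of the failure mode this fixes, assuming a diffusers release that includes this change; the tensor values in the comments are approximate and depend on the rounding done inside set_timesteps(). With timestep_spacing="trailing" (or "linspace") the inference timesteps are not spaced by an exact num_train_timesteps // num_inference_steps, so the old fallback arithmetic in previous_timestep() could return a value that is not in scheduler.timesteps; the new branch instead looks the previous value up in scheduler.timesteps whenever num_inference_steps is set.

# Minimal sketch, assuming a diffusers version that contains this fix.
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000, timestep_spacing="trailing")
scheduler.set_timesteps(num_inference_steps=3)
print(scheduler.timesteps)  # roughly tensor([999, 666, 332]); the gaps are not a constant 333

t = scheduler.timesteps[1]  # 666
print(scheduler.previous_timestep(t))  # tensor(332): the next entry in scheduler.timesteps
# Previous behaviour: 666 - 1000 // 3 = 333, which is not an entry of scheduler.timesteps.
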
@@ -548,16 +548,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         return self.config.num_train_timesteps
 
     def previous_timestep(self, timestep):
-        if self.custom_timesteps:
+        if self.custom_timesteps or self.num_inference_steps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
             if index == self.timesteps.shape[0] - 1:
                 prev_t = torch.tensor(-1)
             else:
                 prev_t = self.timesteps[index + 1]
         else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
+            prev_t = timestep - 1
 
         return prev_t
@@ -639,16 +639,12 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
     def previous_timestep(self, timestep):
-        if self.custom_timesteps:
+        if self.custom_timesteps or self.num_inference_steps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
             if index == self.timesteps.shape[0] - 1:
                 prev_t = torch.tensor(-1)
             else:
                 prev_t = self.timesteps[index + 1]
         else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
+            prev_t = timestep - 1
 
         return prev_t
@@ -643,16 +643,12 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
     def previous_timestep(self, timestep):
-        if self.custom_timesteps:
+        if self.custom_timesteps or self.num_inference_steps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
             if index == self.timesteps.shape[0] - 1:
                 prev_t = torch.tensor(-1)
             else:
                 prev_t = self.timesteps[index + 1]
         else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
+            prev_t = timestep - 1
 
         return prev_t
@@ -680,16 +680,12 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
    def previous_timestep(self, timestep):
-        if self.custom_timesteps:
+        if self.custom_timesteps or self.num_inference_steps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
             if index == self.timesteps.shape[0] - 1:
                 prev_t = torch.tensor(-1)
             else:
                 prev_t = self.timesteps[index + 1]
         else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
+            prev_t = timestep - 1
 
         return prev_t