Unverified Commit 5effcd3e authored by Anand Kumar's avatar Anand Kumar Committed by GitHub
Browse files

[Bug fix] "previous_timestep()" in DDPM scheduling compatible with "trailing"...


[Bug fix] "previous_timestep()" in DDPM scheduling compatible with "trailing" and "linspace" options (#9384)

* Update scheduling_ddpm.py

* fix copies

---------
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
Co-authored-by: default avatarhlky <hlky@hlky.ac>
parent 619b9658
......@@ -548,16 +548,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return self.config.num_train_timesteps
def previous_timestep(self, timestep):
if self.custom_timesteps:
if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
prev_t = torch.tensor(-1)
else:
prev_t = self.timesteps[index + 1]
else:
num_inference_steps = (
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
)
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
prev_t = timestep - 1
return prev_t
......@@ -639,16 +639,12 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep):
if self.custom_timesteps:
if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
prev_t = torch.tensor(-1)
else:
prev_t = self.timesteps[index + 1]
else:
num_inference_steps = (
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
)
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
prev_t = timestep - 1
return prev_t
......@@ -643,16 +643,12 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep):
if self.custom_timesteps:
if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
prev_t = torch.tensor(-1)
else:
prev_t = self.timesteps[index + 1]
else:
num_inference_steps = (
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
)
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
prev_t = timestep - 1
return prev_t
......@@ -680,16 +680,12 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep):
if self.custom_timesteps:
if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
prev_t = torch.tensor(-1)
else:
prev_t = self.timesteps[index + 1]
else:
num_inference_steps = (
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
)
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
prev_t = timestep - 1
return prev_t
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment