Unverified Commit b56f1027 authored by Nathan Lambert's avatar Nathan Lambert Committed by GitHub
Browse files

Fix scheduler inference steps error with power of 3 (#466)

* initial attempt at solving

* fix pndm power of 3 inference_step

* add power of 3 test

* fix index in pndm test, remove ddim test

* add comments, change to round()
parent da990633
...@@ -145,9 +145,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -145,9 +145,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference. optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
""" """
self.num_inference_steps = num_inference_steps self.num_inference_steps = num_inference_steps
self.timesteps = np.arange( step_ratio = self.config.num_train_timesteps // self.num_inference_steps
0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio
)[::-1].copy() # casting to int to avoid issues when num_inference_step is power of 3
self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
self.timesteps += offset self.timesteps += offset
self.set_format(tensor_format=self.tensor_format) self.set_format(tensor_format=self.tensor_format)
......
...@@ -143,9 +143,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -143,9 +143,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference. optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
""" """
self.num_inference_steps = num_inference_steps self.num_inference_steps = num_inference_steps
self._timesteps = list( step_ratio = self.config.num_train_timesteps // self.num_inference_steps
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps) # creates integer timesteps by multiplying by ratio
) # casting to int to avoid issues when num_inference_step is power of 3
self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().tolist()
self._offset = offset self._offset = offset
self._timesteps = np.array([t + self._offset for t in self._timesteps]) self._timesteps = np.array([t + self._offset for t in self._timesteps])
......
...@@ -379,7 +379,7 @@ class DDIMSchedulerTest(SchedulerCommonTest): ...@@ -379,7 +379,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
def test_inference_steps(self): def test_inference_steps(self):
for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]): for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
self.check_over_forward(num_inference_steps=num_inference_steps) self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
def test_eta(self): def test_eta(self):
for t, eta in zip([1, 10, 49], [0.0, 0.5, 1.0]): for t, eta in zip([1, 10, 49], [0.0, 0.5, 1.0]):
...@@ -622,6 +622,23 @@ class PNDMSchedulerTest(SchedulerCommonTest): ...@@ -622,6 +622,23 @@ class PNDMSchedulerTest(SchedulerCommonTest):
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]): for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
def test_pow_of_3_inference_steps(self):
# earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3
num_inference_steps = 27
for scheduler_class in self.scheduler_classes:
sample = self.dummy_sample
residual = 0.1 * sample
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(num_inference_steps)
# before power of 3 fix, would error on first step, so we only need to do two
for i, t in enumerate(scheduler.prk_timesteps[:2]):
sample = scheduler.step_prk(residual, t, sample).prev_sample
def test_inference_plms_no_past_residuals(self): def test_inference_plms_no_past_residuals(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
scheduler_class = self.scheduler_classes[0] scheduler_class = self.scheduler_classes[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment