Unverified Commit 194ed794 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[PNDM] Stable diffusion (#186)

* [PNDM] Stable diffusino

* finish
parent 051b3463
......@@ -56,6 +56,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
beta_end=0.02,
beta_schedule="linear",
tensor_format="pt",
skip_prk_steps=False,
):
if beta_schedule == "linear":
......@@ -88,6 +89,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# setable values
self.num_inference_steps = None
self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
self._offset = 0
self.prk_timesteps = None
self.plms_timesteps = None
self.timesteps = None
......@@ -95,17 +97,27 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps):
def set_timesteps(self, num_inference_steps, offset=0):
self.num_inference_steps = num_inference_steps
self._timesteps = list(
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
)
self._offset = offset
self._timesteps = [t + self._offset for t in self._timesteps]
if self.config.skip_prk_steps:
# for some models like stable diffusion the prk steps can/should be skipped to
# produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
# is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
self.prk_timesteps = []
self.plms_timesteps = list(reversed(self._timesteps[:-1] + self._timesteps[-2:-1] + self._timesteps[-1:]))
else:
prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
)
self.prk_timesteps = list(reversed(prk_timesteps[:-1].repeat(2)[1:-1]))
self.plms_timesteps = list(reversed(self._timesteps[:-3]))
prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
)
self.prk_timesteps = list(reversed(prk_timesteps[:-1].repeat(2)[1:-1]))
self.plms_timesteps = list(reversed(self._timesteps[:-3]))
self.timesteps = self.prk_timesteps + self.plms_timesteps
self.counter = 0
......@@ -117,7 +129,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
):
if self.counter < len(self.prk_timesteps):
if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
return self.step_prk(model_output=model_output, timestep=timestep, sample=sample)
else:
return self.step_plms(model_output=model_output, timestep=timestep, sample=sample)
......@@ -166,7 +178,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
times to approximate the solution.
"""
if len(self.ets) < 3:
if not self.config.skip_prk_steps and len(self.ets) < 3:
raise ValueError(
f"{self.__class__} can only be run AFTER scheduler has been run "
"in 'prk' mode for at least 12 iterations "
......@@ -175,9 +187,26 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
)
prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0)
self.ets.append(model_output)
model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
if self.counter != 1:
self.ets.append(model_output)
else:
prev_timestep = timestep
timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
if len(self.ets) == 1 and self.counter == 0:
model_output = model_output
self.cur_sample = sample
elif len(self.ets) == 1 and self.counter == 1:
model_output = (model_output + self.ets[-1]) / 2
sample = self.cur_sample
self.cur_sample = None
elif len(self.ets) == 2:
model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
elif len(self.ets) == 3:
model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
else:
model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
self.counter += 1
......@@ -197,8 +226,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# sample -> x_t
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(t−δ)
alpha_prod_t = self.alphas_cumprod[timestep + 1]
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1]
alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset]
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset]
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
......
......@@ -843,6 +843,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is suppused to run on GPU")
def test_stable_diffusion(self):
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
prompt = "A painting of a squirrel eating a burger"
......@@ -857,7 +858,7 @@ class PipelineTesterMixin(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.898, 0.9194, 0.91, 0.8955, 0.915, 0.919, 0.9233, 0.9307, 0.8887])
expected_slice = np.array([0.8887, 0.915, 0.91, 0.894, 0.909, 0.912, 0.919, 0.925, 0.883])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment