Unverified Commit 194ed794 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[PNDM] Stable diffusion (#186)

* [PNDM] Stable diffusino

* finish
parent 051b3463
...@@ -56,6 +56,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -56,6 +56,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
beta_end=0.02, beta_end=0.02,
beta_schedule="linear", beta_schedule="linear",
tensor_format="pt", tensor_format="pt",
skip_prk_steps=False,
): ):
if beta_schedule == "linear": if beta_schedule == "linear":
...@@ -88,6 +89,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -88,6 +89,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# setable values # setable values
self.num_inference_steps = None self.num_inference_steps = None
self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy() self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
self._offset = 0
self.prk_timesteps = None self.prk_timesteps = None
self.plms_timesteps = None self.plms_timesteps = None
self.timesteps = None self.timesteps = None
...@@ -95,17 +97,27 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -95,17 +97,27 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.tensor_format = tensor_format self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format) self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps): def set_timesteps(self, num_inference_steps, offset=0):
self.num_inference_steps = num_inference_steps self.num_inference_steps = num_inference_steps
self._timesteps = list( self._timesteps = list(
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps) range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
) )
self._offset = offset
self._timesteps = [t + self._offset for t in self._timesteps]
if self.config.skip_prk_steps:
# for some models like stable diffusion the prk steps can/should be skipped to
# produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
# is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
self.prk_timesteps = []
self.plms_timesteps = list(reversed(self._timesteps[:-1] + self._timesteps[-2:-1] + self._timesteps[-1:]))
else:
prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
)
self.prk_timesteps = list(reversed(prk_timesteps[:-1].repeat(2)[1:-1]))
self.plms_timesteps = list(reversed(self._timesteps[:-3]))
prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
)
self.prk_timesteps = list(reversed(prk_timesteps[:-1].repeat(2)[1:-1]))
self.plms_timesteps = list(reversed(self._timesteps[:-3]))
self.timesteps = self.prk_timesteps + self.plms_timesteps self.timesteps = self.prk_timesteps + self.plms_timesteps
self.counter = 0 self.counter = 0
...@@ -117,7 +129,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -117,7 +129,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
timestep: int, timestep: int,
sample: Union[torch.FloatTensor, np.ndarray], sample: Union[torch.FloatTensor, np.ndarray],
): ):
if self.counter < len(self.prk_timesteps): if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
return self.step_prk(model_output=model_output, timestep=timestep, sample=sample) return self.step_prk(model_output=model_output, timestep=timestep, sample=sample)
else: else:
return self.step_plms(model_output=model_output, timestep=timestep, sample=sample) return self.step_plms(model_output=model_output, timestep=timestep, sample=sample)
...@@ -166,7 +178,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -166,7 +178,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
times to approximate the solution. times to approximate the solution.
""" """
if len(self.ets) < 3: if not self.config.skip_prk_steps and len(self.ets) < 3:
raise ValueError( raise ValueError(
f"{self.__class__} can only be run AFTER scheduler has been run " f"{self.__class__} can only be run AFTER scheduler has been run "
"in 'prk' mode for at least 12 iterations " "in 'prk' mode for at least 12 iterations "
...@@ -175,9 +187,26 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -175,9 +187,26 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
) )
prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0) prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0)
self.ets.append(model_output)
model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) if self.counter != 1:
self.ets.append(model_output)
else:
prev_timestep = timestep
timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
if len(self.ets) == 1 and self.counter == 0:
model_output = model_output
self.cur_sample = sample
elif len(self.ets) == 1 and self.counter == 1:
model_output = (model_output + self.ets[-1]) / 2
sample = self.cur_sample
self.cur_sample = None
elif len(self.ets) == 2:
model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
elif len(self.ets) == 3:
model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
else:
model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
self.counter += 1 self.counter += 1
...@@ -197,8 +226,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -197,8 +226,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# sample -> x_t # sample -> x_t
# model_output -> e_θ(x_t, t) # model_output -> e_θ(x_t, t)
# prev_sample -> x_(t−δ) # prev_sample -> x_(t−δ)
alpha_prod_t = self.alphas_cumprod[timestep + 1] alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset]
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1] alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset]
beta_prod_t = 1 - alpha_prod_t beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev beta_prod_t_prev = 1 - alpha_prod_t_prev
......
...@@ -843,6 +843,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -843,6 +843,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow @slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is suppused to run on GPU") @unittest.skipIf(torch_device == "cpu", "Stable diffusion is suppused to run on GPU")
def test_stable_diffusion(self): def test_stable_diffusion(self):
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers") sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
prompt = "A painting of a squirrel eating a burger" prompt = "A painting of a squirrel eating a burger"
...@@ -857,7 +858,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -857,7 +858,7 @@ class PipelineTesterMixin(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 512, 512, 3) assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.898, 0.9194, 0.91, 0.8955, 0.915, 0.919, 0.9233, 0.9307, 0.8887]) expected_slice = np.array([0.8887, 0.915, 0.91, 0.894, 0.909, 0.912, 0.919, 0.925, 0.883])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow @slow
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment