Commit 7e11392d authored by Patrick von Platen's avatar Patrick von Platen
Browse files

fix ddpm scheduler

parent 1f49a343
......@@ -51,13 +51,7 @@ class DDPMPipeline(DiffusionPipeline):
# 2. predict previous mean of image x_t-1
pred_prev_image = self.scheduler.step(model_output, t, image)["prev_sample"]
# 3. optionally sample variance
variance = 0
if t > 0:
noise = torch.randn(image.shape, generator=generator).to(image.device)
variance = self.scheduler.get_variance(t).sqrt() * noise
# 4. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image + variance
# 3. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image
return {"sample": image}
......@@ -101,7 +101,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
)[::-1].copy()
self.set_format(tensor_format=self.tensor_format)
def get_variance(self, t, variance_type=None):
def _get_variance(self, t, variance_type=None):
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
......@@ -133,6 +133,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
predict_epsilon=True,
generator=None,
):
t = timestep
# 1. compute alphas, betas
......@@ -161,6 +162,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
# 6. Add noise
variance = 0
if t > 0:
noise = torch.randn(model_output.shape, generator=generator).to(model_output.device)
variance = self._get_variance(t).sqrt() * noise
pred_prev_sample = pred_prev_sample + variance
return {"prev_sample": pred_prev_sample}
def add_noise(self, original_samples, noise, timesteps):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment