Commit 29628acb authored by Patrick von Platen's avatar Patrick von Platen
Browse files

renaming of api

parent 9d2fc6b5
......@@ -45,15 +45,15 @@ class DDIMPipeline(DiffusionPipeline):
self.scheduler.set_timesteps(num_inference_steps)
for t in tqdm.tqdm(self.scheduler.timesteps):
# 1. predict noise residual
# 1. predict noise model_output
with torch.no_grad():
residual = self.unet(image, t)
model_output = self.unet(image, t)
if isinstance(residual, dict):
residual = residual["sample"]
if isinstance(model_output, dict):
model_output = model_output["sample"]
# 2. predict previous mean of image x_t-1 and add variance depending on eta
# do x_t -> x_t-1
image = self.scheduler.step(residual, t, image, eta)["prev_sample"]
image = self.scheduler.step(model_output, t, image, eta)["prev_sample"]
return {"sample": image}
......@@ -42,15 +42,15 @@ class DDPMPipeline(DiffusionPipeline):
num_prediction_steps = len(self.scheduler)
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
# 1. predict noise residual
# 1. predict noise model_output
with torch.no_grad():
residual = self.unet(image, t)
model_output = self.unet(image, t)
if isinstance(residual, dict):
residual = residual["sample"]
if isinstance(model_output, dict):
model_output = model_output["sample"]
# 2. predict previous mean of image x_t-1
pred_prev_image = self.scheduler.step(residual, t, image)["prev_sample"]
pred_prev_image = self.scheduler.step(model_output, t, image)["prev_sample"]
# 3. optionally sample variance
variance = 0
......
......@@ -36,14 +36,14 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
self.scheduler.set_timesteps(num_inference_steps)
for t in tqdm.tqdm(self.scheduler.timesteps):
residual = self.unet(image, t)
model_output = self.unet(image, t)
if isinstance(residual, dict):
residual = residual["sample"]
if isinstance(model_output, dict):
model_output = model_output["sample"]
# 2. predict previous mean of image x_t-1 and add variance depending on eta
# do x_t -> x_t-1
image = self.scheduler.step(residual, t, image, eta)["prev_sample"]
image = self.scheduler.step(model_output, t, image, eta)["prev_sample"]
# decode image with vae
image = self.vqvae.decode(image)
......
......@@ -45,21 +45,21 @@ class PNDMPipeline(DiffusionPipeline):
prk_time_steps = self.scheduler.get_prk_time_steps(num_inference_steps)
for t in tqdm.tqdm(range(len(prk_time_steps))):
t_orig = prk_time_steps[t]
residual = self.unet(image, t_orig)
model_output = self.unet(image, t_orig)
if isinstance(residual, dict):
residual = residual["sample"]
if isinstance(model_output, dict):
model_output = model_output["sample"]
image = self.scheduler.step_prk(residual, t, image, num_inference_steps)["prev_sample"]
image = self.scheduler.step_prk(model_output, t, image, num_inference_steps)["prev_sample"]
timesteps = self.scheduler.get_time_steps(num_inference_steps)
for t in tqdm.tqdm(range(len(timesteps))):
t_orig = timesteps[t]
residual = self.unet(image, t_orig)
model_output = self.unet(image, t_orig)
if isinstance(residual, dict):
residual = residual["sample"]
if isinstance(model_output, dict):
model_output = model_output["sample"]
image = self.scheduler.step_plms(residual, t, image, num_inference_steps)["prev_sample"]
image = self.scheduler.step_plms(model_output, t, image, num_inference_steps)["prev_sample"]
return image
......@@ -112,11 +112,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
residual: Union[torch.FloatTensor, np.ndarray],
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
eta,
use_clipped_residual=False,
use_clipped_model_output=False,
generator=None,
):
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
......@@ -140,7 +140,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
# 4. Clip "predicted x_0"
if self.config.clip_sample:
......@@ -151,22 +151,22 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
variance = self._get_variance(timestep, prev_timestep)
std_dev_t = eta * variance ** (0.5)
if use_clipped_residual:
# the residual is always re-derived from the clipped x_0 in Glide
residual = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
if use_clipped_model_output:
# the model_output is always re-derived from the clipped x_0 in Glide
model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if eta > 0:
device = residual.device if torch.is_tensor(residual) else "cpu"
noise = torch.randn(residual.shape, generator=generator).to(device)
device = model_output.device if torch.is_tensor(model_output) else "cpu"
noise = torch.randn(model_output.shape, generator=generator).to(device)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
if not torch.is_tensor(residual):
if not torch.is_tensor(model_output):
variance = variance.numpy()
prev_sample = prev_sample + variance
......
......@@ -116,7 +116,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
def step(
self,
residual: Union[torch.FloatTensor, np.ndarray],
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
predict_epsilon=True,
......@@ -131,9 +131,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if predict_epsilon:
pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
else:
pred_original_sample = residual
pred_original_sample = model_output
# 3. Clip "predicted x_0"
if self.config.clip_sample:
......
......@@ -85,7 +85,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.pndm_order = 4
# running values
self.cur_residual = 0
self.cur_model_output = 0
self.cur_sample = None
self.ets = []
self.prk_time_steps = {}
......@@ -130,7 +130,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
def step_prk(
self,
residual: Union[torch.FloatTensor, np.ndarray],
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
num_inference_steps,
......@@ -142,25 +142,25 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
t_orig_prev = prk_time_steps[min(t + 1, len(prk_time_steps) - 1)]
if t % 4 == 0:
self.cur_residual += 1 / 6 * residual
self.ets.append(residual)
self.cur_model_output += 1 / 6 * model_output
self.ets.append(model_output)
self.cur_sample = sample
elif (t - 1) % 4 == 0:
self.cur_residual += 1 / 3 * residual
self.cur_model_output += 1 / 3 * model_output
elif (t - 2) % 4 == 0:
self.cur_residual += 1 / 3 * residual
self.cur_model_output += 1 / 3 * model_output
elif (t - 3) % 4 == 0:
residual = self.cur_residual + 1 / 6 * residual
self.cur_residual = 0
model_output = self.cur_model_output + 1 / 6 * model_output
self.cur_model_output = 0
# cur_sample should not be `None`
cur_sample = self.cur_sample if self.cur_sample is not None else sample
return {"prev_sample": self.get_prev_sample(cur_sample, t_orig, t_orig_prev, residual)}
return {"prev_sample": self.get_prev_sample(cur_sample, t_orig, t_orig_prev, model_output)}
def step_plms(
self,
residual: Union[torch.FloatTensor, np.ndarray],
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
num_inference_steps,
......@@ -178,13 +178,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
t_orig = timesteps[t]
t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)]
self.ets.append(residual)
self.ets.append(model_output)
residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
return {"prev_sample": self.get_prev_sample(sample, t_orig, t_orig_prev, residual)}
return {"prev_sample": self.get_prev_sample(sample, t_orig, t_orig_prev, model_output)}
def get_prev_sample(self, sample, t_orig, t_orig_prev, residual):
def get_prev_sample(self, sample, t_orig, t_orig_prev, model_output):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
# this function computes x_(t−δ) using the formula of (9)
# Note that x_t needs to be added to both sides of the equation
......@@ -195,7 +195,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# beta_prod_t -> (1 - α_t)
# beta_prod_t_prev -> (1 - α_(t−δ))
# sample -> x_t
# residual -> e_θ(x_t, t)
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(t−δ)
alpha_prod_t = self.alphas_cumprod[t_orig + 1]
alpha_prod_t_prev = self.alphas_cumprod[t_orig_prev + 1]
......@@ -209,12 +209,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)
# corresponds to denominator of e_θ(x_t, t) in formula (9)
residual_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
alpha_prod_t * beta_prod_t * alpha_prod_t_prev
) ** (0.5)
# full formula (9)
prev_sample = sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * residual / residual_denom_coeff
prev_sample = sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
return prev_sample
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment