Unverified Commit c72e3430 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[PNDM in LDM pipeline] use inspect in pipeline instead of unused kwargs (#167)

use inspect instead of unused kwargs
parent 3228eb16
import inspect
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
...@@ -59,6 +60,12 @@ class LDMTextToImagePipeline(DiffusionPipeline): ...@@ -59,6 +60,12 @@ class LDMTextToImagePipeline(DiffusionPipeline):
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
# Prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
# `eta` is forwarded only when the scheduler's `step` actually declares an `eta` parameter;
# passing it to a `step` without that parameter (and without **kwargs) raises a TypeError.
# NOTE: the name `extra_kwrags` is kept as spelled because the `scheduler.step(...)` call
# further down unpacks it under this name.
accepts_eta = "eta" in inspect.signature(self.scheduler.step).parameters
extra_kwrags = {}
if accepts_eta:
    extra_kwrags["eta"] = eta
for t in tqdm(self.scheduler.timesteps): for t in tqdm(self.scheduler.timesteps):
if guidance_scale == 1.0: if guidance_scale == 1.0:
# guidance_scale of 1 means no guidance # guidance_scale of 1 means no guidance
...@@ -79,7 +86,7 @@ class LDMTextToImagePipeline(DiffusionPipeline): ...@@ -79,7 +86,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, eta=eta)["prev_sample"] latents = self.scheduler.step(noise_pred, t, latents, **extra_kwrags)["prev_sample"]
# scale and decode the image latents with vae # scale and decode the image latents with vae
latents = 1 / 0.18215 * latents latents = 1 / 0.18215 * latents
......
import inspect
import torch import torch
from tqdm.auto import tqdm from tqdm.auto import tqdm
...@@ -31,11 +33,17 @@ class LDMPipeline(DiffusionPipeline): ...@@ -31,11 +33,17 @@ class LDMPipeline(DiffusionPipeline):
self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_timesteps(num_inference_steps)
# Prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
# `eta` is forwarded only when the scheduler's `step` actually declares an `eta` parameter;
# passing it to a `step` without that parameter (and without **kwargs) raises a TypeError.
# NOTE: the name `extra_kwrags` is kept as spelled because the `scheduler.step(...)` call
# further down unpacks it under this name.
accepts_eta = "eta" in inspect.signature(self.scheduler.step).parameters
extra_kwrags = {}
if accepts_eta:
    extra_kwrags["eta"] = eta
for t in tqdm(self.scheduler.timesteps): for t in tqdm(self.scheduler.timesteps):
# predict the noise residual # predict the noise residual
noise_prediction = self.unet(latents, t)["sample"] noise_prediction = self.unet(latents, t)["sample"]
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_prediction, t, latents, eta)["prev_sample"] latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwrags)["prev_sample"]
# decode the image latents with the VAE # decode the image latents with the VAE
image = self.vqvae.decode(latents) image = self.vqvae.decode(latents)
......
...@@ -116,7 +116,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -116,7 +116,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
model_output: Union[torch.FloatTensor, np.ndarray], model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int, timestep: int,
sample: Union[torch.FloatTensor, np.ndarray], sample: Union[torch.FloatTensor, np.ndarray],
**kwargs,
): ):
if self.counter < len(self.prk_timesteps): if self.counter < len(self.prk_timesteps):
return self.step_prk(model_output=model_output, timestep=timestep, sample=sample) return self.step_prk(model_output=model_output, timestep=timestep, sample=sample)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment