Commit 3a32b8c9 authored by Patrick von Platen

align API

parent c3a15437
@@ -27,6 +27,7 @@ class DDIMPipeline(DiffusionPipeline):
         scheduler = scheduler.set_format("pt")
         self.register_modules(unet=unet, scheduler=scheduler)
 
+    @torch.no_grad()
     def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):
         # eta corresponds to η in paper and should be between [0, 1]
 
         if torch_device is None:
@@ -46,11 +47,7 @@ class DDIMPipeline(DiffusionPipeline):
         for t in tqdm(self.scheduler.timesteps):
             # 1. predict noise model_output
-            with torch.no_grad():
-                model_output = self.unet(image, t)
-
-            if isinstance(model_output, dict):
-                model_output = model_output["sample"]
+            model_output = self.unet(image, t)["sample"]
 
             # 2. predict previous mean of image x_t-1 and add variance depending on eta
             # do x_t -> x_t-1
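For orientation, a minimal usage sketch of the aligned DDIM API. It is not part of the diff: `unet` and `scheduler` are assumed to be built the same way as in the tests further down, the import path is assumed, and the dict return for DDIM is inferred from the DDPM/PNDM hunks below rather than shown in this hunk.

# Sketch only: `unet` and a DDIM `scheduler` are assumed to exist, mirroring the test setup below.
import torch
from diffusers import DDIMPipeline  # import path assumed

ddim = DDIMPipeline(unet=unet, scheduler=scheduler)

generator = torch.manual_seed(0)
# __call__ is now decorated with @torch.no_grad(), so no explicit no_grad context is needed.
output = ddim(generator=generator, eta=0.0, num_inference_steps=50)
# If DDIM returns the same {"sample": ...} dict as DDPM/PNDM after this commit (assumption),
# unwrap it; otherwise the output is already the image tensor.
image = output["sample"] if isinstance(output, dict) else output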
@@ -27,6 +27,7 @@ class DDPMPipeline(DiffusionPipeline):
         scheduler = scheduler.set_format("pt")
         self.register_modules(unet=unet, scheduler=scheduler)
 
+    @torch.no_grad()
     def __call__(self, batch_size=1, generator=None, torch_device=None):
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -45,11 +46,7 @@ class DDPMPipeline(DiffusionPipeline):
         for t in tqdm(self.scheduler.timesteps):
             # 1. predict noise model_output
-            with torch.no_grad():
-                model_output = self.unet(image, t)
-
-            if isinstance(model_output, dict):
-                model_output = model_output["sample"]
+            model_output = self.unet(image, t)["sample"]
 
             # 2. predict previous mean of image x_t-1
             pred_prev_image = self.scheduler.step(model_output, t, image)["prev_sample"]
@@ -63,4 +60,4 @@ class DDPMPipeline(DiffusionPipeline):
         # 4. set current image to prev_image: x_t -> x_t-1
         image = pred_prev_image + variance
 
-        return image
+        return {"sample": image}
@@ -27,6 +27,7 @@ class PNDMPipeline(DiffusionPipeline):
         scheduler = scheduler.set_format("pt")
         self.register_modules(unet=unet, scheduler=scheduler)
 
+    @torch.no_grad()
     def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
         # For more information on the sampling method you can take a look at Algorithm 2 of
         # the official paper: https://arxiv.org/pdf/2202.09778.pdf
@@ -45,21 +46,15 @@ class PNDMPipeline(DiffusionPipeline):
         prk_time_steps = self.scheduler.get_prk_time_steps(num_inference_steps)
         for t in tqdm(range(len(prk_time_steps))):
             t_orig = prk_time_steps[t]
-            model_output = self.unet(image, t_orig)
-
-            if isinstance(model_output, dict):
-                model_output = model_output["sample"]
+            model_output = self.unet(image, t_orig)["sample"]
 
             image = self.scheduler.step_prk(model_output, t, image, num_inference_steps)["prev_sample"]
 
         timesteps = self.scheduler.get_time_steps(num_inference_steps)
         for t in tqdm(range(len(timesteps))):
             t_orig = timesteps[t]
-            model_output = self.unet(image, t_orig)
-
-            if isinstance(model_output, dict):
-                model_output = model_output["sample"]
+            model_output = self.unet(image, t_orig)["sample"]
 
             image = self.scheduler.step_plms(model_output, t, image, num_inference_steps)["prev_sample"]
 
-        return image
+        return {"sample": image}
@@ -665,9 +665,9 @@ class PipelineTesterMixin(unittest.TestCase):
 
         generator = torch.manual_seed(0)
-        image = ddpm(generator=generator)
+        image = ddpm(generator=generator)["sample"]
 
         generator = generator.manual_seed(0)
-        new_image = new_ddpm(generator=generator)
+        new_image = new_ddpm(generator=generator)["sample"]
 
         assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
@@ -683,9 +683,9 @@ class PipelineTesterMixin(unittest.TestCase):
 
         generator = torch.manual_seed(0)
-        image = ddpm(generator=generator)
+        image = ddpm(generator=generator)["sample"]
 
         generator = generator.manual_seed(0)
-        new_image = ddpm_from_hub(generator=generator)
+        new_image = ddpm_from_hub(generator=generator)["sample"]
 
         assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
@@ -700,7 +700,7 @@ class PipelineTesterMixin(unittest.TestCase):
         ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
 
         generator = torch.manual_seed(0)
-        image = ddpm(generator=generator)
+        image = ddpm(generator=generator)["sample"]
 
         image_slice = image[0, -1, -3:, -3:].cpu()
@@ -759,7 +759,7 @@ class PipelineTesterMixin(unittest.TestCase):
         pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
 
         generator = torch.manual_seed(0)
-        image = pndm(generator=generator)
+        image = pndm(generator=generator)["sample"]
 
         image_slice = image[0, -1, -3:, -3:].cpu()