Commit 6b275fca authored by anton-l's avatar anton-l
Browse files

make PIL the default output type

parent 1b42732c
......@@ -28,7 +28,7 @@ class DDIMPipeline(DiffusionPipeline):
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):
def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"):
# eta corresponds to η in paper and should be between [0, 1]
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
......@@ -55,5 +55,7 @@ class DDIMPipeline(DiffusionPipeline):
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image}
......@@ -28,7 +28,7 @@ class DDPMPipeline(DiffusionPipeline):
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(self, batch_size=1, generator=None, torch_device=None, output_type="numpy"):
def __call__(self, batch_size=1, generator=None, torch_device=None, output_type="pil"):
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
......
......@@ -30,7 +30,7 @@ class LatentDiffusionPipeline(DiffusionPipeline):
eta=0.0,
guidance_scale=1.0,
num_inference_steps=50,
output_type="numpy",
output_type="pil",
):
# eta corresponds to η in paper and should be between [0, 1]
......
......@@ -13,7 +13,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
@torch.no_grad()
def __call__(
self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="numpy"
self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"
):
# eta corresponds to η in paper and should be between [0, 1]
......
......@@ -28,7 +28,7 @@ class PNDMPipeline(DiffusionPipeline):
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50, output_type="numpy"):
def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50, output_type="pil"):
# For more information on the sampling method you can take a look at Algorithm 2 of
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
if torch_device is None:
......
......@@ -11,7 +11,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
self.register_modules(model=model, scheduler=scheduler)
@torch.no_grad()
def __call__(self, num_inference_steps=2000, generator=None, output_type="numpy"):
def __call__(self, num_inference_steps=2000, generator=None, output_type="pil"):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
img_size = self.model.config.image_size
......
......@@ -704,9 +704,9 @@ class PipelineTesterMixin(unittest.TestCase):
generator = torch.manual_seed(0)
image = ddpm(generator=generator)["sample"]
image = ddpm(generator=generator, output_type="numpy")["sample"]
generator = generator.manual_seed(0)
new_image = new_ddpm(generator=generator)["sample"]
new_image = new_ddpm(generator=generator, output_type="numpy")["sample"]
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
......@@ -722,9 +722,9 @@ class PipelineTesterMixin(unittest.TestCase):
generator = torch.manual_seed(0)
image = ddpm(generator=generator)["sample"]
image = ddpm(generator=generator, output_type="numpy")["sample"]
generator = generator.manual_seed(0)
new_image = ddpm_from_hub(generator=generator)["sample"]
new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
......@@ -735,10 +735,6 @@ class PipelineTesterMixin(unittest.TestCase):
pipe = DDIMPipeline.from_pretrained(model_path)
generator = torch.manual_seed(0)
images = pipe(generator=generator)["sample"]
assert images.shape == (1, 32, 32, 3)
assert isinstance(images, np.ndarray)
images = pipe(generator=generator, output_type="numpy")["sample"]
assert images.shape == (1, 32, 32, 3)
assert isinstance(images, np.ndarray)
......@@ -748,6 +744,11 @@ class PipelineTesterMixin(unittest.TestCase):
assert len(images) == 1
assert isinstance(images[0], PIL.Image.Image)
# use PIL by default
images = pipe(generator=generator)["sample"]
assert isinstance(images, list)
assert isinstance(images[0], PIL.Image.Image)
@slow
def test_ddpm_cifar10(self):
model_id = "google/ddpm-cifar10-32"
......@@ -759,7 +760,7 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0)
image = ddpm(generator=generator)["sample"]
image = ddpm(generator=generator, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1]
......@@ -777,7 +778,7 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0)
image = ddpm(generator=generator)["sample"]
image = ddpm(generator=generator, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1]
......@@ -795,7 +796,7 @@ class PipelineTesterMixin(unittest.TestCase):
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0)
image = ddim(generator=generator, eta=0.0)["sample"]
image = ddim(generator=generator, eta=0.0, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1]
......@@ -812,7 +813,7 @@ class PipelineTesterMixin(unittest.TestCase):
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0)
image = pndm(generator=generator)["sample"]
image = pndm(generator=generator, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1]
......@@ -826,7 +827,9 @@ class PipelineTesterMixin(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20)["sample"]
image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy")[
"sample"
]
image_slice = image[0, -3:, -3:, -1]
......@@ -840,7 +843,7 @@ class PipelineTesterMixin(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
image = ldm([prompt], generator=generator, num_inference_steps=1)["sample"]
image = ldm([prompt], generator=generator, num_inference_steps=1, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1]
......@@ -861,7 +864,7 @@ class PipelineTesterMixin(unittest.TestCase):
sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
torch.manual_seed(0)
image = sde_ve(num_inference_steps=300)["sample"]
image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1]
......@@ -874,7 +877,7 @@ class PipelineTesterMixin(unittest.TestCase):
ldm = LatentDiffusionUncondPipeline.from_pretrained("CompVis/ldm-celebahq-256")
generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=5)["sample"]
image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment