Commit 6b275fca authored by anton-l
Browse files

make PIL the default output type

parent 1b42732c
...@@ -28,7 +28,7 @@ class DDIMPipeline(DiffusionPipeline): ...@@ -28,7 +28,7 @@ class DDIMPipeline(DiffusionPipeline):
self.register_modules(unet=unet, scheduler=scheduler) self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50): def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"):
# eta corresponds to η in paper and should be between [0, 1] # eta corresponds to η in paper and should be between [0, 1]
if torch_device is None: if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu" torch_device = "cuda" if torch.cuda.is_available() else "cpu"
...@@ -55,5 +55,7 @@ class DDIMPipeline(DiffusionPipeline): ...@@ -55,5 +55,7 @@ class DDIMPipeline(DiffusionPipeline):
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy() image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image} return {"sample": image}
...@@ -28,7 +28,7 @@ class DDPMPipeline(DiffusionPipeline): ...@@ -28,7 +28,7 @@ class DDPMPipeline(DiffusionPipeline):
self.register_modules(unet=unet, scheduler=scheduler) self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
def __call__(self, batch_size=1, generator=None, torch_device=None, output_type="numpy"): def __call__(self, batch_size=1, generator=None, torch_device=None, output_type="pil"):
if torch_device is None: if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu" torch_device = "cuda" if torch.cuda.is_available() else "cpu"
......
...@@ -30,7 +30,7 @@ class LatentDiffusionPipeline(DiffusionPipeline): ...@@ -30,7 +30,7 @@ class LatentDiffusionPipeline(DiffusionPipeline):
eta=0.0, eta=0.0,
guidance_scale=1.0, guidance_scale=1.0,
num_inference_steps=50, num_inference_steps=50,
output_type="numpy", output_type="pil",
): ):
# eta corresponds to η in paper and should be between [0, 1] # eta corresponds to η in paper and should be between [0, 1]
......
...@@ -13,7 +13,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline): ...@@ -13,7 +13,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="numpy" self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"
): ):
# eta corresponds to η in paper and should be between [0, 1] # eta corresponds to η in paper and should be between [0, 1]
......
...@@ -28,7 +28,7 @@ class PNDMPipeline(DiffusionPipeline): ...@@ -28,7 +28,7 @@ class PNDMPipeline(DiffusionPipeline):
self.register_modules(unet=unet, scheduler=scheduler) self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50, output_type="numpy"): def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50, output_type="pil"):
# For more information on the sampling method you can take a look at Algorithm 2 of # For more information on the sampling method you can take a look at Algorithm 2 of
# the official paper: https://arxiv.org/pdf/2202.09778.pdf # the official paper: https://arxiv.org/pdf/2202.09778.pdf
if torch_device is None: if torch_device is None:
......
...@@ -11,7 +11,7 @@ class ScoreSdeVePipeline(DiffusionPipeline): ...@@ -11,7 +11,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
self.register_modules(model=model, scheduler=scheduler) self.register_modules(model=model, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
def __call__(self, num_inference_steps=2000, generator=None, output_type="numpy"): def __call__(self, num_inference_steps=2000, generator=None, output_type="pil"):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
img_size = self.model.config.image_size img_size = self.model.config.image_size
......
...@@ -704,9 +704,9 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -704,9 +704,9 @@ class PipelineTesterMixin(unittest.TestCase):
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ddpm(generator=generator)["sample"] image = ddpm(generator=generator, output_type="numpy")["sample"]
generator = generator.manual_seed(0) generator = generator.manual_seed(0)
new_image = new_ddpm(generator=generator)["sample"] new_image = new_ddpm(generator=generator, output_type="numpy")["sample"]
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
...@@ -722,9 +722,9 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -722,9 +722,9 @@ class PipelineTesterMixin(unittest.TestCase):
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ddpm(generator=generator)["sample"] image = ddpm(generator=generator, output_type="numpy")["sample"]
generator = generator.manual_seed(0) generator = generator.manual_seed(0)
new_image = ddpm_from_hub(generator=generator)["sample"] new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
...@@ -735,10 +735,6 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -735,10 +735,6 @@ class PipelineTesterMixin(unittest.TestCase):
pipe = DDIMPipeline.from_pretrained(model_path) pipe = DDIMPipeline.from_pretrained(model_path)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
images = pipe(generator=generator)["sample"]
assert images.shape == (1, 32, 32, 3)
assert isinstance(images, np.ndarray)
images = pipe(generator=generator, output_type="numpy")["sample"] images = pipe(generator=generator, output_type="numpy")["sample"]
assert images.shape == (1, 32, 32, 3) assert images.shape == (1, 32, 32, 3)
assert isinstance(images, np.ndarray) assert isinstance(images, np.ndarray)
...@@ -748,6 +744,11 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -748,6 +744,11 @@ class PipelineTesterMixin(unittest.TestCase):
assert len(images) == 1 assert len(images) == 1
assert isinstance(images[0], PIL.Image.Image) assert isinstance(images[0], PIL.Image.Image)
# use PIL by default
images = pipe(generator=generator)["sample"]
assert isinstance(images, list)
assert isinstance(images[0], PIL.Image.Image)
@slow @slow
def test_ddpm_cifar10(self): def test_ddpm_cifar10(self):
model_id = "google/ddpm-cifar10-32" model_id = "google/ddpm-cifar10-32"
...@@ -759,7 +760,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -759,7 +760,7 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler) ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ddpm(generator=generator)["sample"] image = ddpm(generator=generator, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
...@@ -777,7 +778,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -777,7 +778,7 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler) ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ddpm(generator=generator)["sample"] image = ddpm(generator=generator, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
...@@ -795,7 +796,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -795,7 +796,7 @@ class PipelineTesterMixin(unittest.TestCase):
ddim = DDIMPipeline(unet=unet, scheduler=scheduler) ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ddim(generator=generator, eta=0.0)["sample"] image = ddim(generator=generator, eta=0.0, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
...@@ -812,7 +813,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -812,7 +813,7 @@ class PipelineTesterMixin(unittest.TestCase):
pndm = PNDMPipeline(unet=unet, scheduler=scheduler) pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = pndm(generator=generator)["sample"] image = pndm(generator=generator, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
...@@ -826,7 +827,9 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -826,7 +827,9 @@ class PipelineTesterMixin(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger" prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20)["sample"] image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy")[
"sample"
]
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
...@@ -840,7 +843,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -840,7 +843,7 @@ class PipelineTesterMixin(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger" prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ldm([prompt], generator=generator, num_inference_steps=1)["sample"] image = ldm([prompt], generator=generator, num_inference_steps=1, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
...@@ -861,7 +864,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -861,7 +864,7 @@ class PipelineTesterMixin(unittest.TestCase):
sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler) sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
torch.manual_seed(0) torch.manual_seed(0)
image = sde_ve(num_inference_steps=300)["sample"] image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
...@@ -874,7 +877,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -874,7 +877,7 @@ class PipelineTesterMixin(unittest.TestCase):
ldm = LatentDiffusionUncondPipeline.from_pretrained("CompVis/ldm-celebahq-256") ldm = LatentDiffusionUncondPipeline.from_pretrained("CompVis/ldm-celebahq-256")
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=5)["sample"] image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment