Unverified commit 1b42732c, authored by Anton Lozhkov and committed by GitHub
Browse files

PIL-ify the pipeline outputs (#111)

parent 9e9d2dbc
...@@ -19,6 +19,7 @@ import os ...@@ -19,6 +19,7 @@ import os
from typing import Optional, Union from typing import Optional, Union
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from PIL import Image
from .configuration_utils import ConfigMixin from .configuration_utils import ConfigMixin
from .utils import DIFFUSERS_CACHE, logging from .utils import DIFFUSERS_CACHE, logging
...@@ -189,3 +190,15 @@ class DiffusionPipeline(ConfigMixin): ...@@ -189,3 +190,15 @@ class DiffusionPipeline(ConfigMixin):
# 5. Instantiate the pipeline # 5. Instantiate the pipeline
model = pipeline_class(**init_kwargs) model = pipeline_class(**init_kwargs)
return model return model
@staticmethod
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a list of PIL images.

    Args:
        images: channels-last array, either a single image `(height, width, channels)` or a batch
            `(batch, height, width, channels)`. Values are multiplied by 255 before the cast to `uint8`,
            so they are expected to lie in `[0, 1]`.

    Returns:
        A list with one PIL image per entry of the batch (a single image yields a one-element list).
    """
    # Promote a single image to a batch of one so both cases share the same code path.
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # `Image.fromarray` rejects `(height, width, 1)` arrays; drop the channel axis so
        # single-channel images are converted as 2-D grayscale instead of raising.
        images = images.squeeze(axis=-1)
    return [Image.fromarray(image) for image in images]
...@@ -28,7 +28,7 @@ class DDPMPipeline(DiffusionPipeline): ...@@ -28,7 +28,7 @@ class DDPMPipeline(DiffusionPipeline):
self.register_modules(unet=unet, scheduler=scheduler) self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
def __call__(self, batch_size=1, generator=None, torch_device=None): def __call__(self, batch_size=1, generator=None, torch_device=None, output_type="numpy"):
if torch_device is None: if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu" torch_device = "cuda" if torch.cuda.is_available() else "cpu"
...@@ -56,5 +56,7 @@ class DDPMPipeline(DiffusionPipeline): ...@@ -56,5 +56,7 @@ class DDPMPipeline(DiffusionPipeline):
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy() image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image} return {"sample": image}
...@@ -30,6 +30,7 @@ class LatentDiffusionPipeline(DiffusionPipeline): ...@@ -30,6 +30,7 @@ class LatentDiffusionPipeline(DiffusionPipeline):
eta=0.0, eta=0.0,
guidance_scale=1.0, guidance_scale=1.0,
num_inference_steps=50, num_inference_steps=50,
output_type="numpy",
): ):
# eta corresponds to η in paper and should be between [0, 1] # eta corresponds to η in paper and should be between [0, 1]
...@@ -86,6 +87,8 @@ class LatentDiffusionPipeline(DiffusionPipeline): ...@@ -86,6 +87,8 @@ class LatentDiffusionPipeline(DiffusionPipeline):
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy() image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image} return {"sample": image}
......
...@@ -13,12 +13,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline): ...@@ -13,12 +13,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="numpy"
batch_size=1,
generator=None,
torch_device=None,
eta=0.0,
num_inference_steps=50,
): ):
# eta corresponds to η in paper and should be between [0, 1] # eta corresponds to η in paper and should be between [0, 1]
...@@ -47,5 +42,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline): ...@@ -47,5 +42,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy() image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image} return {"sample": image}
...@@ -28,7 +28,7 @@ class PNDMPipeline(DiffusionPipeline): ...@@ -28,7 +28,7 @@ class PNDMPipeline(DiffusionPipeline):
self.register_modules(unet=unet, scheduler=scheduler) self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50): def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50, output_type="numpy"):
# For more information on the sampling method you can take a look at Algorithm 2 of # For more information on the sampling method you can take a look at Algorithm 2 of
# the official paper: https://arxiv.org/pdf/2202.09778.pdf # the official paper: https://arxiv.org/pdf/2202.09778.pdf
if torch_device is None: if torch_device is None:
...@@ -59,5 +59,7 @@ class PNDMPipeline(DiffusionPipeline): ...@@ -59,5 +59,7 @@ class PNDMPipeline(DiffusionPipeline):
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy() image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image} return {"sample": image}
...@@ -11,7 +11,7 @@ class ScoreSdeVePipeline(DiffusionPipeline): ...@@ -11,7 +11,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
self.register_modules(model=model, scheduler=scheduler) self.register_modules(model=model, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
def __call__(self, num_inference_steps=2000, generator=None): def __call__(self, num_inference_steps=2000, generator=None, output_type="numpy"):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
img_size = self.model.config.image_size img_size = self.model.config.image_size
...@@ -47,5 +47,7 @@ class ScoreSdeVePipeline(DiffusionPipeline): ...@@ -47,5 +47,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
sample = sample.clamp(0, 1) sample = sample.clamp(0, 1)
sample = sample.cpu().permute(0, 2, 3, 1).numpy() sample = sample.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
sample = self.numpy_to_pil(sample)
return {"sample": sample} return {"sample": sample}
...@@ -18,11 +18,11 @@ import inspect ...@@ -18,11 +18,11 @@ import inspect
import math import math
import tempfile import tempfile
import unittest import unittest
from atexit import register
import numpy as np import numpy as np
import torch import torch
import PIL
from diffusers import UNetConditionalModel # noqa: F401 TODO(Patrick) - need to write tests with it from diffusers import UNetConditionalModel # noqa: F401 TODO(Patrick) - need to write tests with it
from diffusers import ( from diffusers import (
AutoencoderKL, AutoencoderKL,
...@@ -728,6 +728,26 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -728,6 +728,26 @@ class PipelineTesterMixin(unittest.TestCase):
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
@slow
def test_output_format(self):
    """Check that the pipeline returns a numpy batch by default and PIL images for `output_type="pil"`."""
    checkpoint = "google/ddpm-cifar10-32"
    pipeline = DDIMPipeline.from_pretrained(checkpoint)
    generator = torch.manual_seed(0)

    # The implicit default and the explicit "numpy" setting must both yield a batched array.
    for call_kwargs in ({}, {"output_type": "numpy"}):
        samples = pipeline(generator=generator, **call_kwargs)["sample"]
        assert samples.shape == (1, 32, 32, 3)
        assert isinstance(samples, np.ndarray)

    # With "pil" the sample is a list holding one PIL image per batch entry.
    pil_samples = pipeline(generator=generator, output_type="pil")["sample"]
    assert isinstance(pil_samples, list)
    assert len(pil_samples) == 1
    assert isinstance(pil_samples[0], PIL.Image.Image)
@slow @slow
def test_ddpm_cifar10(self): def test_ddpm_cifar10(self):
model_id = "google/ddpm-cifar10-32" model_id = "google/ddpm-cifar10-32"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment