Unverified Commit 5a59f9b7 authored by Duong A. Nguyen's avatar Duong A. Nguyen Committed by GitHub
Browse files

Add LDM Super Resolution pipeline (#1116)



* Add ldm super resolution pipeline

* style

* fix copies

* style

* fix doc

* Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* add doc

* address comments

* address comments

* fix doc

* minor

* add tests

* add tests

* load text encoder from subfolder

* fix test

* fix test

* style

* style

* handle mps latents

* unfix typo

* unfix typo

* Update tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

* fix set_timesteps mps

* fix set_timesteps mps

* Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>

* style

* test 64x64 instead of 256x256
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
parent b93fe085
......@@ -33,6 +33,7 @@ The original codebase can be found [here](https://github.com/CompVis/latent-diff
| Pipeline | Tasks | Colab
|---|---|:---:|
| [pipeline_latent_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py) | *Text-to-Image Generation* | - |
| [pipeline_latent_diffusion_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py) | *Super Resolution* | - |
## Examples:
......@@ -40,3 +41,7 @@ The original codebase can be found [here](https://github.com/CompVis/latent-diff
## LDMTextToImagePipeline
[[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion.LDMTextToImagePipeline
- __call__
## LDMSuperResolutionPipeline
[[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion_superresolution.LDMSuperResolutionPipeline
- __call__
......@@ -35,6 +35,7 @@ if is_torch_available():
DDPMPipeline,
KarrasVePipeline,
LDMPipeline,
LDMSuperResolutionPipeline,
PNDMPipeline,
RePaintPipeline,
ScoreSdeVePipeline,
......
......@@ -5,6 +5,7 @@ if is_torch_available():
from .dance_diffusion import DanceDiffusionPipeline
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline
from .latent_diffusion import LDMSuperResolutionPipeline
from .latent_diffusion_uncond import LDMPipeline
from .pndm import PNDMPipeline
from .repaint import RePaintPipeline
......
# flake8: noqa
from ...utils import is_transformers_available
from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline
if is_transformers_available():
......
import inspect
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
import PIL
from ...models import UNet2DModel, VQModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
class LDMSuperResolutionPipeline(DiffusionPipeline):
r"""
A pipeline for image super-resolution using Latent
This class inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Parameters:
vqvae ([`VQModel`]):
Vector-quantized (VQ) VAE Model to encode and decode images to and from latent representations.
unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`],
[`EulerAncestralDiscreteScheduler`], [`DPMSolverMultistepScheduler`], or [`PNDMScheduler`].
"""
def __init__(
self,
vqvae: VQModel,
unet: UNet2DModel,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
):
super().__init__()
self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(
self,
init_image: Union[torch.Tensor, PIL.Image.Image],
batch_size: Optional[int] = 1,
num_inference_steps: Optional[int] = 100,
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
) -> Union[Tuple, ImagePipelineOutput]:
r"""
Args:
init_image (`torch.Tensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
batch_size (`int`, *optional*, defaults to 1):
Number of images to generate.
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
generated images.
"""
if isinstance(init_image, PIL.Image.Image):
batch_size = 1
elif isinstance(init_image, torch.Tensor):
batch_size = init_image.shape[0]
else:
raise ValueError(
f"`init_image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(init_image)}"
)
if isinstance(init_image, PIL.Image.Image):
init_image = preprocess(init_image)
height, width = init_image.shape[-2:]
# in_channels should be 6: 3 for latents, 3 for low resolution image
latents_shape = (batch_size, self.unet.in_channels // 2, height, width)
latents_dtype = next(self.unet.parameters()).dtype
if self.device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype)
latents = latents.to(self.device)
else:
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
init_image = init_image.to(device=self.device, dtype=latents_dtype)
# set timesteps and move to the correct device
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
timesteps_tensor = self.scheduler.timesteps
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_kwargs = {}
if accepts_eta:
extra_kwargs["eta"] = eta
for t in self.progress_bar(timesteps_tensor):
# concat latents and low resolution image in the channel dimension.
latents_input = torch.cat([latents, init_image], dim=1)
latents_input = self.scheduler.scale_model_input(latents_input, t)
# predict the noise residual
noise_pred = self.unet(latents_input, t).sample
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
# decode the image latents with the VQVAE
image = self.vqvae.decode(latents).sample
image = torch.clamp(image, -1.0, 1.0)
image = image / 2 + 0.5
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
......@@ -227,6 +227,21 @@ class LDMPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class LDMSuperResolutionPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class PNDMPipeline(metaclass=DummyObject):
_backends = ["torch"]
......
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest
import numpy as np
import torch
import PIL
from diffusers import DDIMScheduler, LDMSuperResolutionPipeline, UNet2DModel, VQModel
from diffusers.utils import floats_tensor, load_image, slow, torch_device
from diffusers.utils.testing_utils import require_torch
from ...test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
class LDMSuperResolutionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@property
def dummy_image(self):
batch_size = 1
num_channels = 3
sizes = (32, 32)
image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
return image
@property
def dummy_uncond_unet(self):
torch.manual_seed(0)
model = UNet2DModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=6,
out_channels=3,
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
return model
@property
def dummy_vq_model(self):
torch.manual_seed(0)
model = VQModel(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=3,
)
return model
def test_inference_superresolution(self):
unet = self.dummy_uncond_unet
scheduler = DDIMScheduler()
vqvae = self.dummy_vq_model
ldm = LDMSuperResolutionPipeline(unet=unet, vqvae=vqvae, scheduler=scheduler)
ldm.to(torch_device)
ldm.set_progress_bar_config(disable=None)
init_image = self.dummy_image.to(torch_device)
# Warmup pass when using mps (see #372)
if torch_device == "mps":
generator = torch.manual_seed(0)
_ = ldm(init_image, generator=generator, num_inference_steps=1, output_type="numpy").images
generator = torch.manual_seed(0)
image = ldm(init_image, generator=generator, num_inference_steps=2, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.8634, 0.8186, 0.6416, 0.6846, 0.4427, 0.5676, 0.4679, 0.6247, 0.5176])
tolerance = 1e-2 if torch_device != "mps" else 3e-2
assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
@slow
@require_torch
class LDMSuperResolutionPipelineIntegrationTests(unittest.TestCase):
def test_inference_superresolution(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/vq_diffusion/teddy_bear_pool.png"
)
init_image = init_image.resize((64, 64), resample=PIL.Image.LANCZOS)
ldm = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution", device_map="auto")
ldm.to(torch_device)
ldm.set_progress_bar_config(disable=None)
generator = torch.Generator(device=torch_device).manual_seed(0)
image = ldm(init_image, generator=generator, num_inference_steps=20, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 256, 256, 3)
expected_slice = np.array([0.7418, 0.7472, 0.7424, 0.7422, 0.7463, 0.726, 0.7382, 0.7248, 0.6828])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment