Unverified commit 2f25156c, authored by hlky and committed by GitHub
Browse files

LEditsPP - examples, check height/width, add tiling/slicing (#10471)

* LEditsPP - examples, check height/width, add tiling/slicing

* make style
parent 6da64065
......@@ -34,21 +34,19 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import PIL
>>> import requests
>>> import torch
>>> from io import BytesIO
>>> from diffusers import LEditsPPPipelineStableDiffusion
>>> from diffusers.utils import load_image
>>> pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
... "runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
... )
>>> pipe.enable_vae_tiling()
>>> pipe = pipe.to("cuda")
>>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png"
>>> image = load_image(img_url).convert("RGB")
>>> image = load_image(img_url).resize((512, 512))
>>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1)
......@@ -152,7 +150,7 @@ class LeditsGaussianSmoothing:
# The gaussian kernel is the product of the gaussian function of each dimension.
kernel = 1
meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij")
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
mean = (size - 1) / 2
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
......@@ -706,6 +704,35 @@ class LEditsPPPipelineStableDiffusion(
def cross_attention_kwargs(self):
    r"""
    Return the value currently held in `self._cross_attention_kwargs`.
    """
    stored_kwargs = self._cross_attention_kwargs
    return stored_kwargs
def enable_vae_slicing(self):
    r"""
    Turn on sliced VAE decoding by delegating to `self.vae.enable_slicing()`.

    With slicing active, the VAE splits the input tensor into slices and computes decoding in several steps,
    which saves some memory and allows larger batch sizes.
    """
    autoencoder = self.vae
    autoencoder.enable_slicing()
def disable_vae_slicing(self):
    r"""
    Turn off sliced VAE decoding by delegating to `self.vae.disable_slicing()`.

    If `enable_vae_slicing` was called earlier, decoding goes back to being computed in a single step.
    """
    autoencoder = self.vae
    autoencoder.disable_slicing()
def enable_vae_tiling(self):
    r"""
    Turn on tiled VAE processing by delegating to `self.vae.enable_tiling()`.

    With tiling active, the VAE splits the input tensor into tiles and computes decoding and encoding in several
    steps. This saves a large amount of memory and allows larger images to be processed.
    """
    autoencoder = self.vae
    autoencoder.enable_tiling()
def disable_vae_tiling(self):
    r"""
    Turn off tiled VAE processing by delegating to `self.vae.disable_tiling()`.

    If `enable_vae_tiling` was called earlier, decoding goes back to being computed in a single step.
    """
    autoencoder = self.vae
    autoencoder.disable_tiling()
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
......@@ -1271,6 +1298,8 @@ class LEditsPPPipelineStableDiffusion(
[`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
and respective VAE reconstruction(s).
"""
if height is not None and height % 32 != 0 or width is not None and width % 32 != 0:
raise ValueError("height and width must be a factor of 32.")
# Reset attn processor, we do not want to store attn maps during inversion
self.unet.set_attn_processor(AttnProcessor())
......@@ -1360,6 +1389,12 @@ class LEditsPPPipelineStableDiffusion(
image = self.image_processor.preprocess(
image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
)
height, width = image.shape[-2:]
if height % 32 != 0 or width % 32 != 0:
raise ValueError(
"Image height and width must be a factor of 32. "
"Consider down-sampling the input using the `height` and `width` parameters"
)
resized = self.image_processor.postprocess(image=image, output_type="pil")
if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
......
......@@ -72,25 +72,18 @@ EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> import PIL
>>> import requests
>>> from io import BytesIO
>>> from diffusers import LEditsPPPipelineStableDiffusionXL
>>> from diffusers.utils import load_image
>>> pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
... "stabilityai/stable-diffusion-xl-base-1.0", variant="fp16", torch_dtype=torch.float16
... )
>>> pipe.enable_vae_tiling()
>>> pipe = pipe.to("cuda")
>>> def download_image(url):
... response = requests.get(url)
... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
>>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg"
>>> image = download_image(img_url)
>>> image = load_image(img_url).resize((1024, 1024))
>>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.2)
......@@ -197,7 +190,7 @@ class LeditsGaussianSmoothing:
# The gaussian kernel is the product of the gaussian function of each dimension.
kernel = 1
meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij")
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
mean = (size - 1) / 2
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
......@@ -768,6 +761,35 @@ class LEditsPPPipelineStableDiffusionXL(
def num_timesteps(self):
    r"""
    Return the value currently held in `self._num_timesteps`.
    """
    step_count = self._num_timesteps
    return step_count
def enable_vae_slicing(self):
    r"""
    Turn on sliced VAE decoding by delegating to `self.vae.enable_slicing()`.

    With slicing active, the VAE splits the input tensor into slices and computes decoding in several steps,
    which saves some memory and allows larger batch sizes.
    """
    autoencoder = self.vae
    autoencoder.enable_slicing()
def disable_vae_slicing(self):
    r"""
    Turn off sliced VAE decoding by delegating to `self.vae.disable_slicing()`.

    If `enable_vae_slicing` was called earlier, decoding goes back to being computed in a single step.
    """
    autoencoder = self.vae
    autoencoder.disable_slicing()
def enable_vae_tiling(self):
    r"""
    Turn on tiled VAE processing by delegating to `self.vae.enable_tiling()`.

    With tiling active, the VAE splits the input tensor into tiles and computes decoding and encoding in several
    steps. This saves a large amount of memory and allows larger images to be processed.
    """
    autoencoder = self.vae
    autoencoder.enable_tiling()
def disable_vae_tiling(self):
    r"""
    Turn off tiled VAE processing by delegating to `self.vae.disable_tiling()`.

    If `enable_vae_tiling` was called earlier, decoding goes back to being computed in a single step.
    """
    autoencoder = self.vae
    autoencoder.disable_tiling()
# Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.prepare_unet
def prepare_unet(self, attention_store, PnP: bool = False):
attn_procs = {}
......@@ -1401,6 +1423,12 @@ class LEditsPPPipelineStableDiffusionXL(
image = self.image_processor.preprocess(
image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
)
height, width = image.shape[-2:]
if height % 32 != 0 or width % 32 != 0:
raise ValueError(
"Image height and width must be a factor of 32. "
"Consider down-sampling the input using the `height` and `width` parameters"
)
resized = self.image_processor.postprocess(image=image, output_type="pil")
if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
......@@ -1439,6 +1467,10 @@ class LEditsPPPipelineStableDiffusionXL(
crops_coords_top_left: Tuple[int, int] = (0, 0),
num_zero_noise_steps: int = 3,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
resize_mode: Optional[str] = "default",
crops_coords: Optional[Tuple[int, int, int, int]] = None,
):
r"""
The function to the pipeline for image inversion as described by the [LEDITS++
......@@ -1486,6 +1518,8 @@ class LEditsPPPipelineStableDiffusionXL(
[`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
and respective VAE reconstruction(s).
"""
if height is not None and height % 32 != 0 or width is not None and width % 32 != 0:
raise ValueError("height and width must be a factor of 32.")
# Reset attn processor, we do not want to store attn maps during inversion
self.unet.set_attn_processor(AttnProcessor())
......@@ -1510,7 +1544,14 @@ class LEditsPPPipelineStableDiffusionXL(
do_classifier_free_guidance = source_guidance_scale > 1.0
# 1. prepare image
x0, resized = self.encode_image(image, dtype=self.text_encoder_2.dtype)
x0, resized = self.encode_image(
image,
dtype=self.text_encoder_2.dtype,
height=height,
width=width,
resize_mode=resize_mode,
crops_coords=crops_coords,
)
width = x0.shape[2] * self.vae_scale_factor
height = x0.shape[3] * self.vae_scale_factor
self.size = (height, width)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment