Unverified Commit 2f25156c authored by hlky, committed by GitHub

LEditsPP - examples, check height/width, add tiling/slicing (#10471)

* LEditsPP - examples, check height/width, add tiling/slicing

* make style
parent 6da64065
@@ -34,21 +34,19 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
-        >>> import PIL
-        >>> import requests
         >>> import torch
-        >>> from io import BytesIO
         >>> from diffusers import LEditsPPPipelineStableDiffusion
         >>> from diffusers.utils import load_image
         >>> pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
-        ...     "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+        ...     "runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
         ... )
+        >>> pipe.enable_vae_tiling()
         >>> pipe = pipe.to("cuda")
         >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png"
-        >>> image = load_image(img_url).convert("RGB")
+        >>> image = load_image(img_url).resize((512, 512))
         >>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1)
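The docstring example stops at the inversion call. For orientation only (not part of the commit), here is a minimal sketch of the editing step that would typically follow; the prompt and numeric values are illustrative assumptions, not values taken from this diff:

```py
# Hypothetical continuation of the example above: apply an edit after inversion.
# `editing_prompt`, `edit_guidance_scale` and `edit_threshold` are LEDITS++ call
# parameters; the concrete values here are only illustrative.
edited_image = pipe(
    editing_prompt=["cherry blossom"],  # concept(s) to add or remove
    edit_guidance_scale=10.0,           # how strongly the edit is applied
    edit_threshold=0.75,                # limits how much of the image may change
).images[0]
edited_image.save("cherry_blossom_edited.png")
```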
@@ -152,7 +150,7 @@ class LeditsGaussianSmoothing:
         # The gaussian kernel is the product of the gaussian function of each dimension.
         kernel = 1
-        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij")
         for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
             mean = (size - 1) / 2
             kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
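Side note, not part of the commit: `torch.meshgrid` keeps its historical matrix ("ij") indexing by default but emits a warning when the `indexing` argument is omitted, so passing `indexing="ij"` preserves the existing kernel while silencing the warning. A small self-contained check:

```py
import torch

# Same output as the historical default, but without the indexing warning.
a = torch.arange(3, dtype=torch.float32)
b = torch.arange(4, dtype=torch.float32)
gx, gy = torch.meshgrid([a, b], indexing="ij")
assert gx.shape == (3, 4) and gy.shape == (3, 4)
```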
@@ -706,6 +704,35 @@ class LEditsPPPipelineStableDiffusion(
     def cross_attention_kwargs(self):
         return self._cross_attention_kwargs

+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
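Not part of the diff: a short sketch of how the new memory toggles are typically used together, reusing the checkpoint from the docstring example; treat it as illustrative rather than canonical.

```py
import torch
from diffusers import LEditsPPPipelineStableDiffusion

pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
    "runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
).to("cuda")

pipe.enable_vae_slicing()  # decode batched latents one image at a time
pipe.enable_vae_tiling()   # encode/decode large images tile by tile

# ... pipe.invert(...) and pipe(...) run as usual, with lower peak VAE memory ...

pipe.disable_vae_slicing()  # restore single-pass decoding
pipe.disable_vae_tiling()
```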
@@ -1271,6 +1298,8 @@ class LEditsPPPipelineStableDiffusion(
             [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
             and respective VAE reconstruction(s).
         """
+        if height is not None and height % 32 != 0 or width is not None and width % 32 != 0:
+            raise ValueError("height and width must be a factor of 32.")
         # Reset attn processor, we do not want to store attn maps during inversion
         self.unet.set_attn_processor(AttnProcessor())
@@ -1360,6 +1389,12 @@ class LEditsPPPipelineStableDiffusion(
         image = self.image_processor.preprocess(
             image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
         )
+        height, width = image.shape[-2:]
+        if height % 32 != 0 or width % 32 != 0:
+            raise ValueError(
+                "Image height and width must be a factor of 32. "
+                "Consider down-sampling the input using the `height` and `width` parameters"
+            )
         resized = self.image_processor.postprocess(image=image, output_type="pil")

         if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
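Not part of the diff: the new checks require the preprocessed image's height and width to be multiples of 32. A minimal sketch of preparing an arbitrary input for `invert`, assuming a pipeline `pipe` built as in the docstring example:

```py
from diffusers.utils import load_image

image = load_image("https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png")

# Round each side down to the nearest multiple of 32 (values are illustrative).
w, h = image.size
w, h = w - w % 32, h - h % 32

# Either resize the image yourself ...
image = image.resize((w, h))
_ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1)

# ... or pass `height`/`width` and let the pipeline's image processor resize it.
_ = pipe.invert(image=image, height=h, width=w, num_inversion_steps=50, skip=0.1)
```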
@@ -72,25 +72,18 @@ EXAMPLE_DOC_STRING = """
     Examples:
         ```py
         >>> import torch
-        >>> import PIL
-        >>> import requests
-        >>> from io import BytesIO
         >>> from diffusers import LEditsPPPipelineStableDiffusionXL
+        >>> from diffusers.utils import load_image
         >>> pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
-        ...     "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ...     "stabilityai/stable-diffusion-xl-base-1.0", variant="fp16", torch_dtype=torch.float16
         ... )
+        >>> pipe.enable_vae_tiling()
         >>> pipe = pipe.to("cuda")
-        >>> def download_image(url):
-        ...     response = requests.get(url)
-        ...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")
         >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg"
-        >>> image = download_image(img_url)
+        >>> image = load_image(img_url).resize((1024, 1024))
         >>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.2)
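Again for orientation only (not part of the commit), a sketch of a multi-concept edit that could follow inversion in the SDXL pipeline; parameter names follow the LEDITS++ call API and all values are illustrative assumptions:

```py
# Hypothetical continuation of the SDXL example above.
edited_image = pipe(
    editing_prompt=["tennis ball", "tomato"],
    reverse_editing_direction=[True, False],  # remove the first concept, add the second
    edit_guidance_scale=[5.0, 10.0],
    edit_threshold=[0.9, 0.85],
).images[0]
edited_image.save("tennis_edited.png")
```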
@@ -197,7 +190,7 @@ class LeditsGaussianSmoothing:
         # The gaussian kernel is the product of the gaussian function of each dimension.
         kernel = 1
-        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij")
         for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
             mean = (size - 1) / 2
             kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
@@ -768,6 +761,35 @@ class LEditsPPPipelineStableDiffusionXL(
     def num_timesteps(self):
         return self._num_timesteps

+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.prepare_unet
     def prepare_unet(self, attention_store, PnP: bool = False):
         attn_procs = {}
@@ -1401,6 +1423,12 @@ class LEditsPPPipelineStableDiffusionXL(
         image = self.image_processor.preprocess(
             image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
         )
+        height, width = image.shape[-2:]
+        if height % 32 != 0 or width % 32 != 0:
+            raise ValueError(
+                "Image height and width must be a factor of 32. "
+                "Consider down-sampling the input using the `height` and `width` parameters"
+            )
         resized = self.image_processor.postprocess(image=image, output_type="pil")

         if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
@@ -1439,6 +1467,10 @@ class LEditsPPPipelineStableDiffusionXL(
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         num_zero_noise_steps: int = 3,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        resize_mode: Optional[str] = "default",
+        crops_coords: Optional[Tuple[int, int, int, int]] = None,
     ):
         r"""
         The function to the pipeline for image inversion as described by the [LEDITS++
@@ -1486,6 +1518,8 @@ class LEditsPPPipelineStableDiffusionXL(
             [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
             and respective VAE reconstruction(s).
         """
+        if height is not None and height % 32 != 0 or width is not None and width % 32 != 0:
+            raise ValueError("height and width must be a factor of 32.")
         # Reset attn processor, we do not want to store attn maps during inversion
         self.unet.set_attn_processor(AttnProcessor())
@@ -1510,7 +1544,14 @@ class LEditsPPPipelineStableDiffusionXL(
         do_classifier_free_guidance = source_guidance_scale > 1.0

         # 1. prepare image
-        x0, resized = self.encode_image(image, dtype=self.text_encoder_2.dtype)
+        x0, resized = self.encode_image(
+            image,
+            dtype=self.text_encoder_2.dtype,
+            height=height,
+            width=width,
+            resize_mode=resize_mode,
+            crops_coords=crops_coords,
+        )
         width = x0.shape[2] * self.vae_scale_factor
         height = x0.shape[3] * self.vae_scale_factor
         self.size = (height, width)
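Not part of the diff: with these arguments threaded from `invert` into `encode_image`, the SDXL pipeline can resize the input itself. A minimal sketch, assuming the `pipe` and image URL from the docstring example:

```py
from diffusers.utils import load_image

image = load_image("https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg")

# The new keyword arguments are forwarded to the pipeline's image processor;
# height/width must be multiples of 32 (the values here are illustrative).
_ = pipe.invert(
    image=image,
    num_inversion_steps=50,
    skip=0.2,
    height=1024,
    width=1024,
    resize_mode="default",
)
```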