Unverified commit 59900147 authored by YiYi Xu, committed by GitHub

[WIP]Vae preprocessor refactor (PR1) (#3557)



VaeImageProcessor.preprocess refactor

* refactored `VaeImageProcessor`:
   - allow passing optional `height` and `width` arguments to `resize()`
   - add `convert_to_rgb` (see the preprocessing sketch below)
* refactored the `prepare_latents` method of the img2img pipelines so that latents passed directly as the image input are not encoded again (see the latents-as-image sketch below)
* added a test in test_pipelines_common.py to test latents as image inputs
* refactored the img2img pipelines that accept latents as image:
   - controlnet img2img, stable diffusion img2img, instruct_pix2pix
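A minimal usage sketch of the refactored preprocessor (the import path and the `input.png` file name are illustrative; `do_convert_rgb` and the optional `height`/`width` arguments are the pieces added in this PR):

```python
import PIL.Image
from diffusers.image_processor import VaeImageProcessor

# do_convert_rgb is the new config flag; vae_scale_factor defaults to 8
image_processor = VaeImageProcessor(vae_scale_factor=8, do_convert_rgb=True)

pil_image = PIL.Image.open("input.png")  # placeholder input image
# height/width are optional; when given they override the image's own size before
# both are rounded down to the nearest multiple of vae_scale_factor
image = image_processor.preprocess(pil_image, height=512, width=512)
print(image.shape)  # (1, 3, 512, 512) torch tensor, normalized to [-1, 1] when do_normalize=True
```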
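And a sketch of the new latents-as-image path in an img2img pipeline (the model id and the random tensor are placeholders; any 4-channel tensor passed as `image` is treated as latents and skips the VAE encode in `prepare_latents`):

```python
import torch
from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# stands in for latents produced by an earlier pipeline call; shape (batch, 4, h/8, w/8)
latents = torch.randn(1, 4, 64, 64)

# because image.shape[1] == 4, prepare_latents uses the tensor directly instead of
# encoding it with the VAE, then only adds noise according to `strength`
result = pipe(prompt="a photo of a cat", image=latents, strength=0.8)
```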

---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 1a6a647e
@@ -30,7 +30,8 @@ class VaeImageProcessor(ConfigMixin):
     Args:
         do_resize (`bool`, *optional*, defaults to `True`):
-            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
+            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
+            `height` and `width` arguments from the `preprocess` method.
         vae_scale_factor (`int`, *optional*, defaults to `8`):
             VAE scale factor. If `do_resize` is True, the image will be automatically resized to multiples of this
             factor.
@@ -38,6 +39,8 @@ class VaeImageProcessor(ConfigMixin):
             Resampling filter to use when resizing the image.
         do_normalize (`bool`, *optional*, defaults to `True`):
             Whether to normalize the image to [-1,1]
+        do_convert_rgb (`bool`, *optional*, defaults to `False`):
+            Whether to convert the images to RGB format.
     """

     config_name = CONFIG_NAME
@@ -49,11 +52,12 @@ class VaeImageProcessor(ConfigMixin):
         vae_scale_factor: int = 8,
         resample: str = "lanczos",
         do_normalize: bool = True,
+        do_convert_rgb: bool = False,
     ):
         super().__init__()

     @staticmethod
-    def numpy_to_pil(images):
+    def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
         """
         Convert a numpy image or a batch of images to a PIL image.
         """
@@ -69,7 +73,19 @@ class VaeImageProcessor(ConfigMixin):
         return pil_images

     @staticmethod
-    def numpy_to_pt(images):
+    def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
+        """
+        Convert a PIL image or a list of PIL images to numpy arrays.
+        """
+        if not isinstance(images, list):
+            images = [images]
+        images = [np.array(image).astype(np.float32) / 255.0 for image in images]
+        images = np.stack(images, axis=0)
+
+        return images
+
+    @staticmethod
+    def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
         """
         Convert a numpy image to a pytorch tensor
         """
@@ -80,7 +96,7 @@ class VaeImageProcessor(ConfigMixin):
         return images

     @staticmethod
-    def pt_to_numpy(images):
+    def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
         """
         Convert a pytorch tensor to a numpy image
         """
@@ -101,18 +117,39 @@ class VaeImageProcessor(ConfigMixin):
         """
         return (images / 2 + 0.5).clamp(0, 1)

-    def resize(self, images: PIL.Image.Image) -> PIL.Image.Image:
+    @staticmethod
+    def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
+        """
+        Converts an image to RGB format.
+        """
+        image = image.convert("RGB")
+        return image
+
+    def resize(
+        self,
+        image: PIL.Image.Image,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+    ) -> PIL.Image.Image:
         """
         Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor`
         """
-        w, h = images.size
-        w, h = (x - x % self.config.vae_scale_factor for x in (w, h))  # resize to integer multiple of vae_scale_factor
-        images = images.resize((w, h), resample=PIL_INTERPOLATION[self.config.resample])
-        return images
+        if height is None:
+            height = image.height
+        if width is None:
+            width = image.width
+
+        width, height = (
+            x - x % self.config.vae_scale_factor for x in (width, height)
+        )  # resize to integer multiple of vae_scale_factor
+        image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
+        return image

     def preprocess(
         self,
         image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
+        height: Optional[int] = None,
+        width: Optional[int] = None,
     ) -> torch.Tensor:
         """
         Preprocess the image input, accepted formats are PIL images, numpy arrays or pytorch tensors"
@@ -126,10 +163,11 @@ class VaeImageProcessor(ConfigMixin):
             )

         if isinstance(image[0], PIL.Image.Image):
+            if self.config.do_convert_rgb:
+                image = [self.convert_to_rgb(i) for i in image]
             if self.config.do_resize:
-                image = [self.resize(i) for i in image]
-            image = [np.array(i).astype(np.float32) / 255.0 for i in image]
-            image = np.stack(image, axis=0)  # to np
+                image = [self.resize(i, height, width) for i in image]
+            image = self.pil_to_numpy(image)  # to np
             image = self.numpy_to_pt(image)  # to pt

         elif isinstance(image[0], np.ndarray):
@@ -146,7 +184,12 @@ class VaeImageProcessor(ConfigMixin):
         elif isinstance(image[0], torch.Tensor):
             image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
-            _, _, height, width = image.shape
+            _, channel, height, width = image.shape
+
+            # don't need any preprocess if the image is latents
+            if channel == 4:
+                return image
+
             if self.config.do_resize and (
                 height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
             ):
...
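For reference, a small round-trip sketch using the conversion helpers typed or added above (all are static methods of `VaeImageProcessor`; `input.png` is a placeholder path):

```python
import PIL.Image
from diffusers.image_processor import VaeImageProcessor

pil_image = VaeImageProcessor.convert_to_rgb(PIL.Image.open("input.png"))

np_image = VaeImageProcessor.pil_to_numpy(pil_image)   # (1, H, W, 3) float32 in [0, 1]
pt_image = VaeImageProcessor.numpy_to_pt(np_image)     # (1, 3, H, W) torch tensor
np_back = VaeImageProcessor.pt_to_numpy(pt_image)      # back to (1, H, W, 3)
pil_back = VaeImageProcessor.numpy_to_pil(np_back)[0]  # numpy_to_pil returns a list of PIL images
```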
@@ -69,6 +69,11 @@ EXAMPLE_DOC_STRING = """
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
+    warnings.warn(
+        "The preprocess method is deprecated and will be removed in a future version. Please"
+        " use VaeImageProcessor.preprocess instead",
+        FutureWarning,
+    )
     if isinstance(image, torch.Tensor):
         return image
     elif isinstance(image, PIL.Image.Image):
@@ -538,13 +543,18 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
         image = image.to(device=device, dtype=dtype)

         batch_size = batch_size * num_images_per_prompt

+        if image.shape[1] == 4:
+            init_latents = image
+
+        else:
             if isinstance(generator, list) and len(generator) != batch_size:
                 raise ValueError(
-                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                    f"You have passed a list of generators of length {len(generator)}, but requested an effective"
+                    f" batch size of {batch_size}. Make sure the batch size matches the length of the generators."
                 )

-            if isinstance(generator, list):
+            elif isinstance(generator, list):
                 init_latents = [
                     self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
                 ]
@@ -586,7 +596,14 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         strength: float = 0.8,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
@@ -609,9 +626,10 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            image (`torch.FloatTensor` or `PIL.Image.Image`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                 `Image`, or tensor representing an image batch, that will be used as the starting point for the
-                process.
+                process. Can also accept image latents as `image`; if latents are passed directly, they will not be
+                encoded again.
             strength (`float`, *optional*, defaults to 0.8):
                 Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                 will be used as a starting point, adding more noise to it the larger the `strength`. The number of
...
@@ -29,7 +29,6 @@ from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
-    PIL_INTERPOLATION,
     is_accelerate_available,
     is_accelerate_version,
     is_compiled_module,
@@ -172,7 +171,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+        self.control_image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+        )
         self.register_to_config(requires_safety_checker=requires_safety_checker)

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
@@ -477,17 +479,12 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
         self,
         prompt,
         image,
-        height,
-        width,
         callback_steps,
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
         controlnet_conditioning_scale=1.0,
     ):
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
         if (callback_steps is None) or (
             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
         ):
@@ -592,21 +589,26 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
     def check_image(self, image, prompt, prompt_embeds):
         image_is_pil = isinstance(image, PIL.Image.Image)
         image_is_tensor = isinstance(image, torch.Tensor)
+        image_is_np = isinstance(image, np.ndarray)
         image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
         image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)

-        if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
+        if (
+            not image_is_pil
+            and not image_is_tensor
+            and not image_is_np
+            and not image_is_pil_list
+            and not image_is_tensor_list
+            and not image_is_np_list
+        ):
             raise TypeError(
-                "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
+                "image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors"
             )

         if image_is_pil:
             image_batch_size = 1
-        elif image_is_tensor:
-            image_batch_size = image.shape[0]
-        elif image_is_pil_list:
-            image_batch_size = len(image)
-        elif image_is_tensor_list:
+        else:
             image_batch_size = len(image)

         if prompt is not None and isinstance(prompt, str):
@@ -633,29 +635,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
         do_classifier_free_guidance=False,
         guess_mode=False,
     ):
-        if not isinstance(image, torch.Tensor):
-            if isinstance(image, PIL.Image.Image):
-                image = [image]
-
-            if isinstance(image[0], PIL.Image.Image):
-                images = []
-
-                for image_ in image:
-                    image_ = image_.convert("RGB")
-                    image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
-                    image_ = np.array(image_)
-                    image_ = image_[None, :]
-                    images.append(image_)
-
-                image = images
-
-                image = np.concatenate(image, axis=0)
-                image = np.array(image).astype(np.float32) / 255.0
-                image = image.transpose(0, 3, 1, 2)
-                image = torch.from_numpy(image)
-            elif isinstance(image[0], torch.Tensor):
-                image = torch.cat(image, dim=0)
+        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)

         image_batch_size = image.shape[0]

         if image_batch_size == 1:
@@ -691,31 +671,6 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
         latents = latents * self.scheduler.init_noise_sigma
         return latents

-    def _default_height_width(self, height, width, image):
-        # NOTE: It is possible that a list of images have different
-        # dimensions for each image, so just checking the first image
-        # is not _exactly_ correct, but it is simple.
-        while isinstance(image, list):
-            image = image[0]
-
-        if height is None:
-            if isinstance(image, PIL.Image.Image):
-                height = image.height
-            elif isinstance(image, torch.Tensor):
-                height = image.shape[2]
-
-            height = (height // 8) * 8  # round down to nearest multiple of 8
-
-        if width is None:
-            if isinstance(image, PIL.Image.Image):
-                width = image.width
-            elif isinstance(image, torch.Tensor):
-                width = image.shape[3]
-
-            width = (width // 8) * 8  # round down to nearest multiple of 8
-
-        return height, width
-
     # override DiffusionPipeline
     def save_pretrained(
         self,
@@ -733,7 +688,14 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
@@ -760,8 +722,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
-                `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+                `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                 The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
                 the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
                 also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
@@ -837,15 +799,11 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
-        # 0. Default height and width to unet
-        height, width = self._default_height_width(height, width, image)
-
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
             image,
-            height,
-            width,
             callback_steps,
             negative_prompt,
             prompt_embeds,
@@ -903,6 +861,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                 do_classifier_free_guidance=do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
+            height, width = image.shape[-2:]
         elif isinstance(controlnet, MultiControlNetModel):
             images = []
@@ -922,6 +881,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                 images.append(image_)

             image = images
+            height, width = image[0].shape[-2:]
         else:
             assert False
...
@@ -29,7 +29,6 @@ from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
-    PIL_INTERPOLATION,
     deprecate,
     is_accelerate_available,
     is_accelerate_version,
@@ -198,7 +197,10 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+        self.control_image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+        )
         self.register_to_config(requires_safety_checker=requires_safety_checker)

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
@@ -503,17 +505,12 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
         self,
         prompt,
         image,
-        height,
-        width,
         callback_steps,
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
         controlnet_conditioning_scale=1.0,
     ):
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
         if (callback_steps is None) or (
             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
         ):
@@ -615,24 +612,30 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
         else:
             assert False

+    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
     def check_image(self, image, prompt, prompt_embeds):
         image_is_pil = isinstance(image, PIL.Image.Image)
         image_is_tensor = isinstance(image, torch.Tensor)
+        image_is_np = isinstance(image, np.ndarray)
         image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
         image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)

-        if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
+        if (
+            not image_is_pil
+            and not image_is_tensor
+            and not image_is_np
+            and not image_is_pil_list
+            and not image_is_tensor_list
+            and not image_is_np_list
+        ):
             raise TypeError(
-                "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
+                "image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors"
             )

         if image_is_pil:
             image_batch_size = 1
-        elif image_is_tensor:
-            image_batch_size = image.shape[0]
-        elif image_is_pil_list:
-            image_batch_size = len(image)
-        elif image_is_tensor_list:
+        else:
             image_batch_size = len(image)

         if prompt is not None and isinstance(prompt, str):
@@ -660,29 +663,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
         do_classifier_free_guidance=False,
         guess_mode=False,
     ):
-        if not isinstance(image, torch.Tensor):
-            if isinstance(image, PIL.Image.Image):
-                image = [image]
-
-            if isinstance(image[0], PIL.Image.Image):
-                images = []
-
-                for image_ in image:
-                    image_ = image_.convert("RGB")
-                    image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
-                    image_ = np.array(image_)
-                    image_ = image_[None, :]
-                    images.append(image_)
-
-                image = images
-
-                image = np.concatenate(image, axis=0)
-                image = np.array(image).astype(np.float32) / 255.0
-                image = image.transpose(0, 3, 1, 2)
-                image = torch.from_numpy(image)
-            elif isinstance(image[0], torch.Tensor):
-                image = torch.cat(image, dim=0)
+        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)

         image_batch_size = image.shape[0]

         if image_batch_size == 1:
@@ -720,13 +701,18 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
         image = image.to(device=device, dtype=dtype)

         batch_size = batch_size * num_images_per_prompt

+        if image.shape[1] == 4:
+            init_latents = image
+
+        else:
             if isinstance(generator, list) and len(generator) != batch_size:
                 raise ValueError(
                     f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                     f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                 )

-            if isinstance(generator, list):
+            elif isinstance(generator, list):
                 init_latents = [
                     self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
                 ]
@@ -763,31 +749,6 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
         return latents

-    def _default_height_width(self, height, width, image):
-        # NOTE: It is possible that a list of images have different
-        # dimensions for each image, so just checking the first image
-        # is not _exactly_ correct, but it is simple.
-        while isinstance(image, list):
-            image = image[0]
-
-        if height is None:
-            if isinstance(image, PIL.Image.Image):
-                height = image.height
-            elif isinstance(image, torch.Tensor):
-                height = image.shape[2]
-
-            height = (height // 8) * 8  # round down to nearest multiple of 8
-
-        if width is None:
-            if isinstance(image, PIL.Image.Image):
-                width = image.width
-            elif isinstance(image, torch.Tensor):
-                width = image.shape[3]
-
-            width = (width // 8) * 8  # round down to nearest multiple of 8
-
-        return height, width
-
     # override DiffusionPipeline
     def save_pretrained(
         self,
@@ -805,9 +766,21 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         control_image: Union[
-            torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
         ] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
@@ -836,8 +809,12 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
-                `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+                `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+                The initial image will be used as the starting point for the image generation process. Can also accept
+                image latents as `image`; if latents are passed directly, they will not be encoded again.
+            control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+                `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                 The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
                 the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
                 also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
@@ -914,15 +891,10 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
-        # 0. Default height and width to unet
-        height, width = self._default_height_width(height, width, image)
-
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
             control_image,
-            height,
-            width,
             callback_steps,
             negative_prompt,
             prompt_embeds,
@@ -966,10 +938,10 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
         )
-        # 4. Prepare image, and controlnet_conditioning_image
-        image = prepare_image(image)
+        # 4. Prepare image
+        image = self.image_processor.preprocess(image).to(dtype=torch.float32)

-        # 5. Prepare image
+        # 5. Prepare controlnet_conditioning_image
         if isinstance(controlnet, ControlNetModel):
             control_image = self.prepare_control_image(
                 image=control_image,
...
@@ -30,7 +30,6 @@ from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
-    PIL_INTERPOLATION,
     is_accelerate_available,
     is_accelerate_version,
     is_compiled_module,
@@ -316,6 +315,9 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.control_image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+        )
         self.register_to_config(requires_safety_checker=requires_safety_checker)

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
@@ -742,24 +744,30 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
         else:
             assert False

+    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
     def check_image(self, image, prompt, prompt_embeds):
         image_is_pil = isinstance(image, PIL.Image.Image)
         image_is_tensor = isinstance(image, torch.Tensor)
+        image_is_np = isinstance(image, np.ndarray)
         image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
         image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)

-        if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
+        if (
+            not image_is_pil
+            and not image_is_tensor
+            and not image_is_np
+            and not image_is_pil_list
+            and not image_is_tensor_list
+            and not image_is_np_list
+        ):
             raise TypeError(
-                "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
+                "image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors"
             )

         if image_is_pil:
             image_batch_size = 1
-        elif image_is_tensor:
-            image_batch_size = image.shape[0]
-        elif image_is_pil_list:
-            image_batch_size = len(image)
-        elif image_is_tensor_list:
+        else:
             image_batch_size = len(image)

         if prompt is not None and isinstance(prompt, str):
@@ -787,29 +795,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
         do_classifier_free_guidance=False,
         guess_mode=False,
     ):
-        if not isinstance(image, torch.Tensor):
-            if isinstance(image, PIL.Image.Image):
-                image = [image]
-
-            if isinstance(image[0], PIL.Image.Image):
-                images = []
-
-                for image_ in image:
-                    image_ = image_.convert("RGB")
-                    image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
-                    image_ = np.array(image_)
-                    image_ = image_[None, :]
-                    images.append(image_)
-
-                image = images
-
-                image = np.concatenate(image, axis=0)
-                image = np.array(image).astype(np.float32) / 255.0
-                image = image.transpose(0, 3, 1, 2)
-                image = torch.from_numpy(image)
-            elif isinstance(image[0], torch.Tensor):
-                image = torch.cat(image, dim=0)
+        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)

         image_batch_size = image.shape[0]

         if image_batch_size == 1:
@@ -983,7 +969,12 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
         image: Union[torch.Tensor, PIL.Image.Image] = None,
         mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
         control_image: Union[
-            torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
         ] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
...
@@ -13,6 +13,7 @@
 # limitations under the License.

+import warnings
 from typing import List, Optional, Tuple, Union

 import numpy as np
@@ -30,6 +31,11 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]):
+    warnings.warn(
+        "The preprocess method is deprecated and will be removed in a future version. Please"
+        " use VaeImageProcessor.preprocess instead",
+        FutureWarning,
+    )
     if isinstance(image, torch.Tensor):
         return image
     elif isinstance(image, PIL.Image.Image):
...
@@ -40,6 +40,11 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
+    warnings.warn(
+        "The preprocess method is deprecated and will be removed in a future version. Please"
+        " use VaeImageProcessor.preprocess instead",
+        FutureWarning,
+    )
     if isinstance(image, torch.Tensor):
         return image
     elif isinstance(image, PIL.Image.Image):
@@ -549,6 +554,11 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         image = image.to(device=device, dtype=dtype)

         batch_size = image.shape[0]

+        if image.shape[1] == 4:
+            init_latents = image
+
+        else:
             if isinstance(generator, list) and len(generator) != batch_size:
                 raise ValueError(
                     f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -599,7 +609,14 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         self,
         prompt: Union[str, List[str]],
         source_prompt: Union[str, List[str]],
-        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         strength: float = 0.8,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
@@ -619,9 +636,10 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         Args:
             prompt (`str` or `List[str]`):
                 The prompt or prompts to guide the image generation.
-            image (`torch.FloatTensor` or `PIL.Image.Image`):
+            image (`torch.FloatTensor`, `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                 `Image`, or tensor representing an image batch, that will be used as the starting point for the
-                process.
+                process. Can also accept image latents as `image`; if latents are passed directly, they will not be
+                encoded again.
             strength (`float`, *optional*, defaults to 0.8):
                 Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                 will be used as a starting point, adding more noise to it the larger the `strength`. The number of
@@ -699,7 +717,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         )

         # 4. Preprocess image
-        image = preprocess(image)
+        image = self.image_processor.preprocess(image)

         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
...
@@ -13,6 +13,7 @@
 # limitations under the License.

 import inspect
+import warnings
 from typing import Callable, List, Optional, Union

 import numpy as np
@@ -33,6 +34,13 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess with 8->64
 def preprocess(image):
+    warnings.warn(
+        (
+            "The preprocess method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor.preprocess instead"
+        ),
+        FutureWarning,
+    )
     if isinstance(image, torch.Tensor):
         return image
     elif isinstance(image, PIL.Image.Image):
...
@@ -37,6 +37,11 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
+    warnings.warn(
+        "The preprocess method is deprecated and will be removed in a future version. Please"
+        " use VaeImageProcessor.preprocess instead",
+        FutureWarning,
+    )
     if isinstance(image, torch.Tensor):
         return image
     elif isinstance(image, PIL.Image.Image):
@@ -423,13 +428,18 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
         image = image.to(device=device, dtype=dtype)

         batch_size = batch_size * num_images_per_prompt

+        if image.shape[1] == 4:
+            init_latents = image
+
+        else:
             if isinstance(generator, list) and len(generator) != batch_size:
                 raise ValueError(
                     f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                     f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                 )

-            if isinstance(generator, list):
+            elif isinstance(generator, list):
                 init_latents = [
                     self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
                 ]
@@ -474,6 +484,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
         if isinstance(image[0], PIL.Image.Image):
             width, height = image[0].size
+        elif isinstance(image[0], np.ndarray):
+            width, height = image[0].shape[:-1]
         else:
             height, width = image[0].shape[-2:]
@@ -512,7 +524,14 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            np.ndarray,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+            List[np.ndarray],
+        ] = None,
         depth_map: Optional[torch.FloatTensor] = None,
         strength: float = 0.8,
         num_inference_steps: Optional[int] = 50,
@@ -535,9 +554,12 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            image (`torch.FloatTensor` or `PIL.Image.Image`):
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                 `Image`, or tensor representing an image batch, that will be used as the starting point for the
-                process.
+                process. Can accept image latents as `image` only if `depth_map` is not `None`.
+            depth_map (`torch.FloatTensor`, *optional*):
+                Depth prediction that will be used as additional conditioning for the image generation process. If not
+                defined, the depth is automatically predicted via `self.depth_estimator`.
             strength (`float`, *optional*, defaults to 0.8):
                 Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                 will be used as a starting point, adding more noise to it the larger the `strength`. The number of
@@ -664,7 +686,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
         )

         # 5. Preprocess image
-        image = preprocess(image)
+        image = self.image_processor.preprocess(image)

         # 6. Set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
...
@@ -159,6 +159,11 @@ def kl_divergence(hidden_states):

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
+    warnings.warn(
+        "The preprocess method is deprecated and will be removed in a future version. Please"
+        " use VaeImageProcessor.preprocess instead",
+        FutureWarning,
+    )
     if isinstance(image, torch.Tensor):
         return image
     elif isinstance(image, PIL.Image.Image):
@@ -799,6 +804,10 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
         image = image.to(device=device, dtype=dtype)

+        if image.shape[1] == 4:
+            latents = image
+
+        else:
             if isinstance(generator, list) and len(generator) != batch_size:
                 raise ValueError(
                     f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -806,7 +815,9 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
                 )

             if isinstance(generator, list):
-                latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)]
+                latents = [
+                    self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                ]
                 latents = torch.cat(latents, dim=0)
             else:
                 latents = self.vae.encode(image).latent_dist.sample(generator)
...
...@@ -73,6 +73,11 @@ EXAMPLE_DOC_STRING = """ ...@@ -73,6 +73,11 @@ EXAMPLE_DOC_STRING = """
def preprocess(image): def preprocess(image):
warnings.warn(
"The preprocess method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor.preprocess instead",
FutureWarning,
)
if isinstance(image, torch.Tensor): if isinstance(image, torch.Tensor):
return image return image
elif isinstance(image, PIL.Image.Image): elif isinstance(image, PIL.Image.Image):
...@@ -441,6 +446,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi ...@@ -441,6 +446,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
return prompt_embeds return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype): def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None: if self.safety_checker is None:
has_nsfw_concept = None has_nsfw_concept = None
...@@ -455,6 +461,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi ...@@ -455,6 +461,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
) )
return image, has_nsfw_concept return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents): def decode_latents(self, latents):
warnings.warn( warnings.warn(
"The decode_latents method is deprecated and will be removed in a future version. Please" "The decode_latents method is deprecated and will be removed in a future version. Please"
@@ -544,13 +551,18 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
image = image.to(device=device, dtype=dtype)
batch_size = batch_size * num_images_per_prompt
+ if image.shape[1] == 4:
+ init_latents = image
+ else:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
- if isinstance(generator, list):
+ elif isinstance(generator, list):
init_latents = [
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
@@ -592,7 +604,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
def __call__(
self,
prompt: Union[str, List[str]] = None,
- image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
@@ -615,9 +634,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
- image (`torch.FloatTensor` or `PIL.Image.Image`):
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
- process.
+ process. Can also accept image latents as `image`; if latents are passed directly, they will not be
+ encoded again.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
......
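The `image` docstring above describes the new latents path for `StableDiffusionImg2ImgPipeline`. A minimal usage sketch of that path follows (not part of the diff; the model id, device, dtype, and latent shape are illustrative assumptions):

```python
# Sketch only: model id, dtype, device, and latent shape are assumptions.
import torch
from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# A latent batch has 4 channels (B, 4, H/8, W/8); prepare_latents sees
# image.shape[1] == 4 and reuses the tensor instead of calling vae.encode again.
latents = torch.randn(1, 4, 64, 64, dtype=torch.float16, device="cuda")

out = pipe(prompt="a photo of an astronaut riding a horse", image=latents, strength=0.8)
out.images[0].save("img2img_from_latents.png")
```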
@@ -43,6 +43,11 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
+ warnings.warn(
+ "The preprocess method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor.preprocess instead",
+ FutureWarning,
+ )
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
@@ -145,7 +150,14 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
def __call__(
self,
prompt: Union[str, List[str]] = None,
- image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
num_inference_steps: int = 100,
guidance_scale: float = 7.5,
image_guidance_scale: float = 1.5,
@@ -168,8 +180,9 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
- image (`PIL.Image.Image`):
+ image (`torch.FloatTensor`, `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
- `Image`, or tensor representing an image batch which will be repainted according to `prompt`.
+ `Image`, or tensor representing an image batch which will be repainted according to `prompt`. Can also
+ accept image latents as `image`; if latents are passed directly, they will not be encoded again.
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
@@ -290,8 +303,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
)
# 3. Preprocess image
- image = preprocess(image)
+ image = self.image_processor.preprocess(image)
- height, width = image.shape[-2:]
# 4. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -308,6 +320,10 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
generator,
)
+ height, width = image_latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
# 6. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
latents = self.prepare_latents(
@@ -746,6 +762,10 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
image = image.to(device=device, dtype=dtype)
batch_size = batch_size * num_images_per_prompt
+ if image.shape[1] == 4:
+ image_latents = image
+ else:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
......
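Because `StableDiffusionInstructPix2PixPipeline` now derives the output resolution from `image_latents` rather than from the raw image tensor, latents passed as `image` still yield the expected size. A small worked sketch of that arithmetic (values are illustrative assumptions):

```python
# Illustrative values only: a 512x512 RGB input encodes to 64x64 latents with the
# standard Stable Diffusion VAE, whose scale factor is 8.
vae_scale_factor = 8                        # 2 ** (len(vae.config.block_out_channels) - 1)
latent_height, latent_width = 64, 64        # image_latents.shape[-2:]
height = latent_height * vae_scale_factor   # 512
width = latent_width * vae_scale_factor     # 512
print(height, width)                        # 512 512
```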
@@ -94,7 +94,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
@@ -291,7 +291,14 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
def __call__(
self,
prompt: Union[str, List[str]],
- image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
+ image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
num_inference_steps: int = 75,
guidance_scale: float = 9.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -308,7 +315,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image upscaling.
- image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, or tensor representing an image batch which will be upscaled. If it's a tensor, it can be
either a latent output from a stable diffusion model, or an image tensor in the range `[-1, 1]`. It
will be considered a `latent` if `image.shape[1]` is `4`; otherwise, it will be considered to be an
@@ -413,7 +420,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
)
# 4. Preprocess image
- image = preprocess(image)
+ image = self.image_processor.preprocess(image)
image = image.to(dtype=text_embeddings.dtype, device=device)
if image.shape[1] == 3:
# encode image if not in latent-space yet
......
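The docstring above notes that a 4-channel tensor is treated as latents by `StableDiffusionLatentUpscalePipeline`. A hedged sketch of the usual chaining pattern (not part of the diff; model ids and parameters are assumptions based on common usage):

```python
# Sketch: hand Stable Diffusion latents straight to the latent upscaler.
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionLatentUpscalePipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
    "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16
).to("cuda")

prompt = "a photo of an astronaut riding a horse"
# output_type="latent" keeps the 4-channel latents instead of decoding to RGB.
low_res_latents = pipe(prompt, output_type="latent").images

# Since low_res_latents.shape[1] == 4, the upscaler skips the VAE encode step.
image = upscaler(
    prompt=prompt, image=low_res_latents, num_inference_steps=20, guidance_scale=0
).images[0]
image.save("upscaled.png")
```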
@@ -177,6 +177,11 @@ EXAMPLE_INVERT_DOC_STRING = """
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
+ warnings.warn(
+ "The preprocess method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor.preprocess instead",
+ FutureWarning,
+ )
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
@@ -629,7 +634,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
def check_inputs(
self,
prompt,
- image,
source_embeds,
target_embeds,
callback_steps,
@@ -727,6 +731,10 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
image = image.to(device=device, dtype=dtype)
+ if image.shape[1] == 4:
+ latents = image
+ else:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -734,7 +742,9 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
)
if isinstance(generator, list):
- latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)]
+ latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+ ]
latents = torch.cat(latents, dim=0)
else:
latents = self.vae.encode(image).latent_dist.sample(generator)
@@ -804,7 +814,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
- image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
source_embeds: torch.Tensor = None,
target_embeds: torch.Tensor = None,
height: Optional[int] = None,
@@ -905,7 +914,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
- image,
source_embeds,
target_embeds,
callback_steps,
@@ -1085,7 +1093,14 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
def invert(
self,
prompt: Optional[str] = None,
- image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ image: Union[
+ torch.FloatTensor,
+ PIL.Image.Image,
+ np.ndarray,
+ List[torch.FloatTensor],
+ List[PIL.Image.Image],
+ List[np.ndarray],
+ ] = None,
num_inference_steps: int = 50,
guidance_scale: float = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -1109,8 +1124,9 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
- image (`PIL.Image.Image`, *optional*):
+ image (`torch.FloatTensor`, `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
- `Image`, or tensor representing an image batch which will be used for conditioning.
+ `Image`, or tensor representing an image batch which will be used for conditioning. Can also accept
+ image latents as `image`; if latents are passed directly, they will not be encoded again.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
@@ -1179,7 +1195,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Preprocess image
- image = preprocess(image)
+ image = self.image_processor.preprocess(image)
# 4. Prepare latent variables
latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator)
@@ -1267,16 +1283,13 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
inverted_latents = latents.detach().clone()
# 8. Post-processing
- image = self.decode_latents(latents.detach())
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
- # 9. Convert to PIL.
- if output_type == "pil":
- image = self.image_processor.numpy_to_pil(image)
if not return_dict:
return (inverted_latents, image)
......
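For the reworked `invert` path of `StableDiffusionPix2PixZeroPipeline` (preprocessing via `self.image_processor.preprocess`, post-processing via `vae.decode` plus `image_processor.postprocess`), a hedged usage sketch; the model id, scheduler wiring, image URL, and the `.latents`/`.images` fields on the output are assumptions based on the pipeline's documented usage:

```python
# Sketch only: ids, the placeholder URL, and output fields are assumptions.
import torch
from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline
from diffusers.utils import load_image

pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    caption_generator=None,
    caption_processor=None,
    safety_checker=None,
    requires_safety_checker=False,
    torch_dtype=torch.float16,
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)

raw_image = load_image("https://example.com/cat.png").resize((512, 512))  # placeholder URL

# invert() now preprocesses with VaeImageProcessor and returns both the inverted
# latents and the decoded, post-processed image.
inv_output = pipe.invert("a photo of a cat", image=raw_image, num_inference_steps=50)
inv_latents, recon_image = inv_output.latents, inv_output.images[0]
```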
import inspect
+ import warnings
from dataclasses import dataclass
from typing import Callable, List, Optional, Union
@@ -34,6 +35,11 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
+ warnings.warn(
+ "The preprocess method is deprecated and will be removed in a future version. Please"
+ " use VaeImageProcessor.preprocess instead",
+ FutureWarning,
+ )
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
......
@@ -40,6 +40,7 @@ class AltDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMix
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
......
@@ -41,7 +41,9 @@ from diffusers.utils.testing_utils import (
)
from ..pipeline_params import (
+ IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_BATCH_PARAMS,
+ TEXT_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
@@ -99,7 +101,8 @@ class ControlNetPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin
pipeline_class = StableDiffusionControlNetPipeline
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = frozenset([])  # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
+ image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
......
@@ -38,6 +38,7 @@ from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
from ..pipeline_params import (
+ IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
)
@@ -51,7 +52,8 @@ class ControlNetImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
pipeline_class = StableDiffusionControlNetImg2ImgPipeline
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
- image_params = frozenset([])  # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
+ image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS.union({"control_image"})
+ image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
......
@@ -40,6 +40,7 @@ from diffusers.utils.testing_utils import enable_full_determinism, require_torch
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
+ TEXT_TO_IMAGE_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
@@ -51,7 +52,8 @@ class ControlNetInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
pipeline_class = StableDiffusionControlNetInpaintPipeline
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
- image_params = frozenset([])
+ image_params = frozenset({"control_image"})  # skip `image` and `mask` for now, only test for control_image
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
......
@@ -25,7 +25,11 @@ from diffusers import AutoencoderKL, CycleDiffusionPipeline, DDIMScheduler, UNet
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, skip_mps
- from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
+ from ..pipeline_params import (
+ IMAGE_TO_IMAGE_IMAGE_PARAMS,
+ TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
+ TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
+ )
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
@@ -42,7 +46,8 @@ class CycleDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterM
}
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"source_prompt"})
- image_params = frozenset([])  # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
+ image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
@@ -101,6 +106,7 @@ class CycleDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterM
def get_dummy_inputs(self, device, seed=0):
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+ image = image / 2 + 0.5
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
......