Unverified Commit 59900147 authored by YiYi Xu and committed by GitHub

[WIP]Vae preprocessor refactor (PR1) (#3557)



VaeImageProcessor.preprocess refactor

* refactored VaeImageProcessor 
   - allow passing optional height and width arguments to resize()
   - add convert_to_rgb
* refactored the prepare_latents method for img2img pipelines so that latents passed directly as the image input are not encoded again
* added a test in test_pipelines_common.py to test latents as image inputs
* refactored img2img pipelines that accept latents as image: 
   - controlnet img2img, stable diffusion img2img, instruct_pix2pix
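
A minimal usage sketch of the refactored behavior described above (not part of the diff; the processor configuration, image sizes, and tensor shapes below are illustrative assumptions):

```python
import torch
from PIL import Image
from diffusers.image_processor import VaeImageProcessor

# Illustrative configuration; do_convert_rgb is the flag introduced in this PR.
processor = VaeImageProcessor(vae_scale_factor=8, do_convert_rgb=True)

# resize() now accepts optional height/width, rounded down to multiples of vae_scale_factor.
img = Image.new("RGB", (515, 771))
resized = processor.resize(img, height=768, width=512)  # -> 512 x 768 PIL image

# A 4-channel tensor is treated as latents: preprocess() returns it as-is (no resize,
# no normalization), which is what lets the img2img pipelines accept latents via `image`
# without re-encoding them through the VAE.
latents = torch.randn(1, 4, 64, 64)
out = processor.preprocess(latents)
print(out.shape)  # torch.Size([1, 4, 64, 64])
```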

---------
Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 1a6a647e
......@@ -30,7 +30,8 @@ class VaeImageProcessor(ConfigMixin):
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
`height` and `width` arguments from the `preprocess` method.
vae_scale_factor (`int`, *optional*, defaults to `8`):
VAE scale factor. If `do_resize` is True, the image will be automatically resized to multiples of this
factor.
......@@ -38,6 +39,8 @@ class VaeImageProcessor(ConfigMixin):
Resampling filter to use when resizing the image.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image to [-1,1]
do_convert_rgb (`bool`, *optional*, defaults to `False`):
Whether to convert the images to RGB format.
"""
config_name = CONFIG_NAME
......@@ -49,11 +52,12 @@ class VaeImageProcessor(ConfigMixin):
vae_scale_factor: int = 8,
resample: str = "lanczos",
do_normalize: bool = True,
do_convert_rgb: bool = False,
):
super().__init__()
@staticmethod
def numpy_to_pil(images):
def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
"""
Convert a numpy image or a batch of images to a PIL image.
"""
......@@ -69,7 +73,19 @@ class VaeImageProcessor(ConfigMixin):
return pil_images
@staticmethod
def numpy_to_pt(images):
def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
"""
Convert a PIL image or a list of PIL images to numpy arrays.
"""
if not isinstance(images, list):
images = [images]
images = [np.array(image).astype(np.float32) / 255.0 for image in images]
images = np.stack(images, axis=0)
return images
@staticmethod
def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
"""
Convert a numpy image to a pytorch tensor
"""
......@@ -80,7 +96,7 @@ class VaeImageProcessor(ConfigMixin):
return images
@staticmethod
def pt_to_numpy(images):
def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
"""
Convert a pytorch tensor to a numpy image
"""
......@@ -101,18 +117,39 @@ class VaeImageProcessor(ConfigMixin):
"""
return (images / 2 + 0.5).clamp(0, 1)
def resize(self, images: PIL.Image.Image) -> PIL.Image.Image:
@staticmethod
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
"""
Converts an image to RGB format.
"""
image = image.convert("RGB")
return image
def resize(
self,
image: PIL.Image.Image,
height: Optional[int] = None,
width: Optional[int] = None,
) -> PIL.Image.Image:
"""
Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor`
"""
w, h = images.size
w, h = (x - x % self.config.vae_scale_factor for x in (w, h)) # resize to integer multiple of vae_scale_factor
images = images.resize((w, h), resample=PIL_INTERPOLATION[self.config.resample])
return images
if height is None:
height = image.height
if width is None:
width = image.width
width, height = (
x - x % self.config.vae_scale_factor for x in (width, height)
) # resize to integer multiple of vae_scale_factor
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
return image
def preprocess(
self,
image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
height: Optional[int] = None,
width: Optional[int] = None,
) -> torch.Tensor:
"""
Preprocess the image input. Accepted formats are PIL images, numpy arrays or pytorch tensors.
......@@ -126,10 +163,11 @@ class VaeImageProcessor(ConfigMixin):
)
if isinstance(image[0], PIL.Image.Image):
if self.config.do_convert_rgb:
image = [self.convert_to_rgb(i) for i in image]
if self.config.do_resize:
image = [self.resize(i) for i in image]
image = [np.array(i).astype(np.float32) / 255.0 for i in image]
image = np.stack(image, axis=0) # to np
image = [self.resize(i, height, width) for i in image]
image = self.pil_to_numpy(image) # to np
image = self.numpy_to_pt(image) # to pt
elif isinstance(image[0], np.ndarray):
......@@ -146,7 +184,12 @@ class VaeImageProcessor(ConfigMixin):
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
_, _, height, width = image.shape
_, channel, height, width = image.shape
# no preprocessing needed if the image is already in latent space
if channel == 4:
return image
if self.config.do_resize and (
height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
):
......
......@@ -69,6 +69,11 @@ EXAMPLE_DOC_STRING = """
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
warnings.warn(
"The preprocess method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor.preprocess instead",
FutureWarning,
)
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
......@@ -538,21 +543,26 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
image = image.to(device=device, dtype=dtype)
batch_size = batch_size * num_images_per_prompt
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if isinstance(generator, list):
init_latents = [
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
if image.shape[1] == 4:
init_latents = image
else:
init_latents = self.vae.encode(image).latent_dist.sample(generator)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective"
f" batch size of {batch_size}. Make sure the batch size matches the length of the generators."
)
elif isinstance(generator, list):
init_latents = [
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
init_latents = self.vae.encode(image).latent_dist.sample(generator)
init_latents = self.vae.config.scaling_factor * init_latents
init_latents = self.vae.config.scaling_factor * init_latents
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
# expand init_latents for batch_size
......@@ -586,7 +596,14 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
......@@ -609,9 +626,10 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`torch.FloatTensor` or `PIL.Image.Image`):
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
process. Can also accept image latents as `image`; if latents are passed directly, they will not be encoded
again.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
......
......@@ -29,7 +29,6 @@ from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
PIL_INTERPOLATION,
is_accelerate_available,
is_accelerate_version,
is_compiled_module,
......@@ -172,7 +171,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
......@@ -477,17 +479,12 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
self,
prompt,
image,
height,
width,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
controlnet_conditioning_scale=1.0,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
......@@ -592,21 +589,26 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
def check_image(self, image, prompt, prompt_embeds):
image_is_pil = isinstance(image, PIL.Image.Image)
image_is_tensor = isinstance(image, torch.Tensor)
image_is_np = isinstance(image, np.ndarray)
image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
if (
not image_is_pil
and not image_is_tensor
and not image_is_np
and not image_is_pil_list
and not image_is_tensor_list
and not image_is_np_list
):
raise TypeError(
"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors"
)
if image_is_pil:
image_batch_size = 1
elif image_is_tensor:
image_batch_size = image.shape[0]
elif image_is_pil_list:
image_batch_size = len(image)
elif image_is_tensor_list:
else:
image_batch_size = len(image)
if prompt is not None and isinstance(prompt, str):
......@@ -633,29 +635,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
do_classifier_free_guidance=False,
guess_mode=False,
):
if not isinstance(image, torch.Tensor):
if isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
images = []
for image_ in image:
image_ = image_.convert("RGB")
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
image_ = np.array(image_)
image_ = image_[None, :]
images.append(image_)
image = images
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image_batch_size = image.shape[0]
if image_batch_size == 1:
......@@ -691,31 +671,6 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
latents = latents * self.scheduler.init_noise_sigma
return latents
def _default_height_width(self, height, width, image):
# NOTE: It is possible that a list of images have different
# dimensions for each image, so just checking the first image
# is not _exactly_ correct, but it is simple.
while isinstance(image, list):
image = image[0]
if height is None:
if isinstance(image, PIL.Image.Image):
height = image.height
elif isinstance(image, torch.Tensor):
height = image.shape[2]
height = (height // 8) * 8 # round down to nearest multiple of 8
if width is None:
if isinstance(image, PIL.Image.Image):
width = image.width
elif isinstance(image, torch.Tensor):
width = image.shape[3]
width = (width // 8) * 8 # round down to nearest multiple of 8
return height, width
# override DiffusionPipeline
def save_pretrained(
self,
......@@ -733,7 +688,14 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
......@@ -760,8 +722,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
`List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
......@@ -837,15 +799,11 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height, width = self._default_height_width(height, width, image)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
image,
height,
width,
callback_steps,
negative_prompt,
prompt_embeds,
......@@ -903,6 +861,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = image.shape[-2:]
elif isinstance(controlnet, MultiControlNetModel):
images = []
......@@ -922,6 +881,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
images.append(image_)
image = images
height, width = image[0].shape[-2:]
else:
assert False
......
......@@ -29,7 +29,6 @@ from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
PIL_INTERPOLATION,
deprecate,
is_accelerate_available,
is_accelerate_version,
......@@ -198,7 +197,10 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
......@@ -503,17 +505,12 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
self,
prompt,
image,
height,
width,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
controlnet_conditioning_scale=1.0,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
......@@ -615,24 +612,30 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
else:
assert False
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
def check_image(self, image, prompt, prompt_embeds):
image_is_pil = isinstance(image, PIL.Image.Image)
image_is_tensor = isinstance(image, torch.Tensor)
image_is_np = isinstance(image, np.ndarray)
image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
if (
not image_is_pil
and not image_is_tensor
and not image_is_np
and not image_is_pil_list
and not image_is_tensor_list
and not image_is_np_list
):
raise TypeError(
"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors"
)
if image_is_pil:
image_batch_size = 1
elif image_is_tensor:
image_batch_size = image.shape[0]
elif image_is_pil_list:
image_batch_size = len(image)
elif image_is_tensor_list:
else:
image_batch_size = len(image)
if prompt is not None and isinstance(prompt, str):
......@@ -660,29 +663,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
do_classifier_free_guidance=False,
guess_mode=False,
):
if not isinstance(image, torch.Tensor):
if isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
images = []
for image_ in image:
image_ = image_.convert("RGB")
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
image_ = np.array(image_)
image_ = image_[None, :]
images.append(image_)
image = images
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image_batch_size = image.shape[0]
if image_batch_size == 1:
......@@ -720,21 +701,26 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
image = image.to(device=device, dtype=dtype)
batch_size = batch_size * num_images_per_prompt
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if isinstance(generator, list):
init_latents = [
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
if image.shape[1] == 4:
init_latents = image
else:
init_latents = self.vae.encode(image).latent_dist.sample(generator)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
init_latents = self.vae.config.scaling_factor * init_latents
elif isinstance(generator, list):
init_latents = [
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
init_latents = self.vae.encode(image).latent_dist.sample(generator)
init_latents = self.vae.config.scaling_factor * init_latents
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
# expand init_latents for batch_size
......@@ -763,31 +749,6 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
return latents
def _default_height_width(self, height, width, image):
# NOTE: It is possible that a list of images have different
# dimensions for each image, so just checking the first image
# is not _exactly_ correct, but it is simple.
while isinstance(image, list):
image = image[0]
if height is None:
if isinstance(image, PIL.Image.Image):
height = image.height
elif isinstance(image, torch.Tensor):
height = image.shape[2]
height = (height // 8) * 8 # round down to nearest multiple of 8
if width is None:
if isinstance(image, PIL.Image.Image):
width = image.width
elif isinstance(image, torch.Tensor):
width = image.shape[3]
width = (width // 8) * 8 # round down to nearest multiple of 8
return height, width
# override DiffusionPipeline
def save_pretrained(
self,
......@@ -805,9 +766,21 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
control_image: Union[
torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
height: Optional[int] = None,
width: Optional[int] = None,
......@@ -836,8 +809,12 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
`List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The initial image will be used as the starting point for the image generation process. Can also accept
image latents as `image`; if latents are passed directly, they will not be encoded again.
control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
......@@ -914,15 +891,10 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height, width = self._default_height_width(height, width, image)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
control_image,
height,
width,
callback_steps,
negative_prompt,
prompt_embeds,
......@@ -966,10 +938,10 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
# 4. Prepare image, and controlnet_conditioning_image
image = prepare_image(image)
# 4. Prepare image
image = self.image_processor.preprocess(image).to(dtype=torch.float32)
# 5. Prepare image
# 5. Prepare controlnet_conditioning_image
if isinstance(controlnet, ControlNetModel):
control_image = self.prepare_control_image(
image=control_image,
......
......@@ -30,7 +30,6 @@ from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
PIL_INTERPOLATION,
is_accelerate_available,
is_accelerate_version,
is_compiled_module,
......@@ -316,6 +315,9 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
......@@ -742,24 +744,30 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
else:
assert False
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
def check_image(self, image, prompt, prompt_embeds):
image_is_pil = isinstance(image, PIL.Image.Image)
image_is_tensor = isinstance(image, torch.Tensor)
image_is_np = isinstance(image, np.ndarray)
image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
if (
not image_is_pil
and not image_is_tensor
and not image_is_np
and not image_is_pil_list
and not image_is_tensor_list
and not image_is_np_list
):
raise TypeError(
"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors"
)
if image_is_pil:
image_batch_size = 1
elif image_is_tensor:
image_batch_size = image.shape[0]
elif image_is_pil_list:
image_batch_size = len(image)
elif image_is_tensor_list:
else:
image_batch_size = len(image)
if prompt is not None and isinstance(prompt, str):
......@@ -787,29 +795,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
do_classifier_free_guidance=False,
guess_mode=False,
):
if not isinstance(image, torch.Tensor):
if isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
images = []
for image_ in image:
image_ = image_.convert("RGB")
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
image_ = np.array(image_)
image_ = image_[None, :]
images.append(image_)
image = images
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image_batch_size = image.shape[0]
if image_batch_size == 1:
......@@ -983,7 +969,12 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversi
image: Union[torch.Tensor, PIL.Image.Image] = None,
mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
control_image: Union[
torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
height: Optional[int] = None,
width: Optional[int] = None,
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import warnings
from typing import List, Optional, Tuple, Union
import numpy as np
......@@ -30,6 +31,11 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]):
warnings.warn(
"The preprocess method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor.preprocess instead",
FutureWarning,
)
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
......
......@@ -40,6 +40,11 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
warnings.warn(
"The preprocess method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor.preprocess instead",
FutureWarning,
)
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
......@@ -549,21 +554,26 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
image = image.to(device=device, dtype=dtype)
batch_size = image.shape[0]
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if isinstance(generator, list):
init_latents = [
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
if image.shape[1] == 4:
init_latents = image
else:
init_latents = self.vae.encode(image).latent_dist.sample(generator)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if isinstance(generator, list):
init_latents = [
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
init_latents = self.vae.encode(image).latent_dist.sample(generator)
init_latents = self.vae.config.scaling_factor * init_latents
init_latents = self.vae.config.scaling_factor * init_latents
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
# expand init_latents for batch_size
......@@ -599,7 +609,14 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
self,
prompt: Union[str, List[str]],
source_prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
......@@ -619,9 +636,10 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
image (`torch.FloatTensor` or `PIL.Image.Image`):
image (`torch.FloatTensor`, `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
process. Can also accept image latents as `image`; if latents are passed directly, they will not be encoded
again.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
......@@ -699,7 +717,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
)
# 4. Preprocess image
image = preprocess(image)
image = self.image_processor.preprocess(image)
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import inspect
import warnings
from typing import Callable, List, Optional, Union
import numpy as np
......@@ -33,6 +34,13 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess with 8->64
def preprocess(image):
warnings.warn(
(
"The preprocess method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor.preprocess instead"
),
FutureWarning,
)
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
......
......@@ -37,6 +37,11 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
warnings.warn(
"The preprocess method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor.preprocess instead",
FutureWarning,
)
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
......@@ -423,21 +428,26 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
image = image.to(device=device, dtype=dtype)
batch_size = batch_size * num_images_per_prompt
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if isinstance(generator, list):
init_latents = [
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
if image.shape[1] == 4:
init_latents = image
else:
init_latents = self.vae.encode(image).latent_dist.sample(generator)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
elif isinstance(generator, list):
init_latents = [
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
init_latents = self.vae.encode(image).latent_dist.sample(generator)
init_latents = self.vae.config.scaling_factor * init_latents
init_latents = self.vae.config.scaling_factor * init_latents
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
# expand init_latents for batch_size
......@@ -474,6 +484,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
if isinstance(image[0], PIL.Image.Image):
width, height = image[0].size
elif isinstance(image[0], np.ndarray):
width, height = image[0].shape[:-1]
else:
height, width = image[0].shape[-2:]
......@@ -512,7 +524,14 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
depth_map: Optional[torch.FloatTensor] = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
......@@ -535,9 +554,12 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`torch.FloatTensor` or `PIL.Image.Image`):
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
process. Can accept image latents as `image` only if `depth_map` is not `None`.
depth_map (`torch.FloatTensor`, *optional*):
Depth prediction that will be used as additional conditioning for the image generation process. If not
defined, the depth will be automatically predicted via `self.depth_estimator`.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
......@@ -664,7 +686,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
)
# 5. Preprocess image
image = preprocess(image)
image = self.image_processor.preprocess(image)
# 6. Set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
......
......@@ -159,6 +159,11 @@ def kl_divergence(hidden_states):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
warnings.warn(
"The preprocess method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor.preprocess instead",
FutureWarning,
)
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
......@@ -799,19 +804,25 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
image = image.to(device=device, dtype=dtype)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if image.shape[1] == 4:
latents = image
if isinstance(generator, list):
latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)]
latents = torch.cat(latents, dim=0)
else:
latents = self.vae.encode(image).latent_dist.sample(generator)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if isinstance(generator, list):
latents = [
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
latents = torch.cat(latents, dim=0)
else:
latents = self.vae.encode(image).latent_dist.sample(generator)
latents = self.vae.config.scaling_factor * latents
latents = self.vae.config.scaling_factor * latents
if batch_size != latents.shape[0]:
if batch_size % latents.shape[0] == 0:
......
......@@ -73,6 +73,11 @@ EXAMPLE_DOC_STRING = """
def preprocess(image):
warnings.warn(
"The preprocess method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor.preprocess instead",
FutureWarning,
)
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
......@@ -441,6 +446,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
......@@ -455,6 +461,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
)
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
warnings.warn(
"The decode_latents method is deprecated and will be removed in a future version. Please"
......@@ -544,21 +551,26 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
image = image.to(device=device, dtype=dtype)
batch_size = batch_size * num_images_per_prompt
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if isinstance(generator, list):
init_latents = [
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
if image.shape[1] == 4:
init_latents = image
else:
init_latents = self.vae.encode(image).latent_dist.sample(generator)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
elif isinstance(generator, list):
init_latents = [
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
init_latents = self.vae.encode(image).latent_dist.sample(generator)
init_latents = self.vae.config.scaling_factor * init_latents
init_latents = self.vae.config.scaling_factor * init_latents
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
# expand init_latents for batch_size
......@@ -592,7 +604,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
......@@ -615,9 +634,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`torch.FloatTensor` or `PIL.Image.Image`):
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
process. Can also accept image latents as `image`; if latents are passed directly, they will not be encoded
again.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
......
......@@ -43,6 +43,11 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
warnings.warn(
"The preprocess method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor.preprocess instead",
FutureWarning,
)
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
......@@ -145,7 +150,14 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
num_inference_steps: int = 100,
guidance_scale: float = 7.5,
image_guidance_scale: float = 1.5,
......@@ -168,8 +180,9 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`PIL.Image.Image`):
`Image`, or tensor representing an image batch which will be repainted according to `prompt`.
image (`torch.FloatTensor`, `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, or tensor representing an image batch which will be repainted according to `prompt`. Can also
accept image latents as `image`; if latents are passed directly, they will not be encoded again.
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
......@@ -290,8 +303,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
)
# 3. Preprocess image
image = preprocess(image)
height, width = image.shape[-2:]
image = self.image_processor.preprocess(image)
# 4. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
......@@ -308,6 +320,10 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
generator,
)
height, width = image_latents.shape[-2:]
height = height * self.vae_scale_factor
width = width * self.vae_scale_factor
# 6. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
latents = self.prepare_latents(
......@@ -746,17 +762,21 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
image = image.to(device=device, dtype=dtype)
batch_size = batch_size * num_images_per_prompt
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if isinstance(generator, list):
image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
image_latents = torch.cat(image_latents, dim=0)
if image.shape[1] == 4:
image_latents = image
else:
image_latents = self.vae.encode(image).latent_dist.mode()
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if isinstance(generator, list):
image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = self.vae.encode(image).latent_dist.mode()
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand image_latents for batch_size
......
......@@ -94,7 +94,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
......@@ -291,7 +291,14 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
num_inference_steps: int = 75,
guidance_scale: float = 9.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
......@@ -308,7 +315,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image upscaling.
image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, or tensor representing an image batch which will be upscaled. If it's a tensor, it can be
either a latent output from a stable diffusion model, or an image tensor in the range `[-1, 1]`. It
will be considered a `latent` if `image.shape[1]` is `4`; otherwise, it will be considered to be an
......@@ -413,7 +420,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
)
# 4. Preprocess image
image = preprocess(image)
image = self.image_processor.preprocess(image)
image = image.to(dtype=text_embeddings.dtype, device=device)
if image.shape[1] == 3:
# encode image if not in latent-space yet
......
......@@ -177,6 +177,11 @@ EXAMPLE_INVERT_DOC_STRING = """
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
warnings.warn(
"The preprocess method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor.preprocess instead",
FutureWarning,
)
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
......@@ -629,7 +634,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
def check_inputs(
self,
prompt,
image,
source_embeds,
target_embeds,
callback_steps,
......@@ -727,19 +731,25 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
image = image.to(device=device, dtype=dtype)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if image.shape[1] == 4:
latents = image
if isinstance(generator, list):
latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)]
latents = torch.cat(latents, dim=0)
else:
latents = self.vae.encode(image).latent_dist.sample(generator)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = self.vae.config.scaling_factor * latents
if isinstance(generator, list):
latents = [
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
latents = torch.cat(latents, dim=0)
else:
latents = self.vae.encode(image).latent_dist.sample(generator)
latents = self.vae.config.scaling_factor * latents
if batch_size != latents.shape[0]:
if batch_size % latents.shape[0] == 0:
......@@ -804,7 +814,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
source_embeds: torch.Tensor = None,
target_embeds: torch.Tensor = None,
height: Optional[int] = None,
......@@ -905,7 +914,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
image,
source_embeds,
target_embeds,
callback_steps,
......@@ -1085,7 +1093,14 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
def invert(
self,
prompt: Optional[str] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
num_inference_steps: int = 50,
guidance_scale: float = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
......@@ -1109,8 +1124,9 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`PIL.Image.Image`, *optional*):
`Image`, or tensor representing an image batch which will be used for conditioning.
image (`torch.FloatTensor`, `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, or tensor representing an image batch which will be used for conditioning. Can also accept
image latents as `image`; if latents are passed directly, they will not be encoded again.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
......@@ -1179,7 +1195,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Preprocess image
image = preprocess(image)
image = self.image_processor.preprocess(image)
# 4. Prepare latent variables
latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator)
......@@ -1267,16 +1283,13 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
inverted_latents = latents.detach().clone()
# 8. Post-processing
image = self.decode_latents(latents.detach())
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
# 9. Convert to PIL.
if output_type == "pil":
image = self.image_processor.numpy_to_pil(image)
if not return_dict:
return (inverted_latents, image)
......
import inspect
import warnings
from dataclasses import dataclass
from typing import Callable, List, Optional, Union
......@@ -34,6 +35,11 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
warnings.warn(
"The preprocess method is deprecated and will be removed in a future version. Please"
" use VaeImageProcessor.preprocess instead",
FutureWarning,
)
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
......
......@@ -40,6 +40,7 @@ class AltDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMix
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -41,7 +41,9 @@ from diffusers.utils.testing_utils import (
)
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
......@@ -99,7 +101,8 @@ class ControlNetPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin
pipeline_class = StableDiffusionControlNetPipeline
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -38,6 +38,7 @@ from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
)
......@@ -51,7 +52,8 @@ class ControlNetImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
pipeline_class = StableDiffusionControlNetImg2ImgPipeline
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS.union({"control_image"})
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -40,6 +40,7 @@ from diffusers.utils.testing_utils import enable_full_determinism, require_torch
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
......@@ -51,7 +52,8 @@ class ControlNetInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
pipeline_class = StableDiffusionControlNetInpaintPipeline
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
image_params = frozenset([])
image_params = frozenset({"control_image"}) # skip `image` and `mask` for now, only test for control_image
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
......
......@@ -25,7 +25,11 @@ from diffusers import AutoencoderKL, CycleDiffusionPipeline, DDIMScheduler, UNet
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, skip_mps
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
)
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
......@@ -42,7 +46,8 @@ class CycleDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterM
}
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"source_prompt"})
image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
......@@ -101,6 +106,7 @@ class CycleDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterM
def get_dummy_inputs(self, device, seed=0):
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image / 2 + 0.5
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
......