Unverified Commit b7b1a30b authored by YiYi Xu, committed by GitHub

refactor prepare_mask_and_masked_image with VaeImageProcessor (#4444)



* refactor image processor for mask
---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
parent 7e5587a5
@@ -24,6 +24,16 @@ from .configuration_utils import ConfigMixin, register_to_config
from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
PipelineImageInput = Union[
PIL.Image.Image,
np.ndarray,
torch.FloatTensor,
List[PIL.Image.Image],
List[np.ndarray],
List[torch.FloatTensor],
]
class VaeImageProcessor(ConfigMixin):
"""
Image processor for VAE.
@@ -38,8 +48,12 @@ class VaeImageProcessor(ConfigMixin):
Resampling filter to use when resizing the image.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image to [-1,1].
do_binarize (`bool`, *optional*, defaults to `False`):
Whether to binarize the image to 0/1.
do_convert_rgb (`bool`, *optional*, defaults to `False`):
Whether to convert the images to RGB format.
do_convert_grayscale (`bool`, *optional*, defaults to `False`):
Whether to convert the images to grayscale format.
"""
config_name = CONFIG_NAME
@@ -51,9 +65,18 @@ class VaeImageProcessor(ConfigMixin):
vae_scale_factor: int = 8,
resample: str = "lanczos",
do_normalize: bool = True,
do_binarize: bool = False,
do_convert_rgb: bool = False,
do_convert_grayscale: bool = False,
):
super().__init__()
if do_convert_rgb and do_convert_grayscale:
raise ValueError(
"`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
" if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
" if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
)
self.config.do_convert_rgb = False
@staticmethod
def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
@@ -119,31 +142,84 @@ class VaeImageProcessor(ConfigMixin):
@staticmethod
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
"""
- Converts an image to RGB format.
+ Converts a PIL image to RGB format.
"""
image = image.convert("RGB")
return image
@staticmethod
def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
"""
Converts a PIL image to grayscale format.
"""
image = image.convert("L")
return image
def get_default_height_width(
self,
image: [PIL.Image.Image, np.ndarray, torch.Tensor],
height: Optional[int] = None,
width: Optional[int] = None,
):
"""
This function returns the height and width that are downscaled to the next integer multiple of
`vae_scale_factor`.
Args:
image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
The image input, which can be a PIL image, numpy array or pytorch tensor. If it is a numpy array, it
should have shape `[batch, height, width]` or `[batch, height, width, channel]`; if it is a pytorch
tensor, it should have shape `[batch, channel, height, width]`.
height (`int`, *optional*, defaults to `None`):
The height of the preprocessed image. If `None`, the height of the `image` input is used.
width (`int`, *optional*, defaults to `None`):
The width of the preprocessed image. If `None`, the width of the `image` input is used.
"""
if height is None:
if isinstance(image, PIL.Image.Image):
height = image.height
elif isinstance(image, torch.Tensor):
height = image.shape[2]
else:
height = image.shape[1]
if width is None:
if isinstance(image, PIL.Image.Image):
width = image.width
elif isinstance(image, torch.Tensor):
width = image.shape[3]
else:
width = image.shape[2]
width, height = (
x - x % self.config.vae_scale_factor for x in (width, height)
) # resize to integer multiple of vae_scale_factor
return height, width
def resize(
self,
image: PIL.Image.Image,
height: Optional[int] = None,
width: Optional[int] = None,
) -> PIL.Image.Image:
"""
- Resize a PIL image. Both height and width are downscaled to the next integer multiple of `vae_scale_factor`.
+ Resize a PIL image.
"""
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
return image
def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
"""
create a mask
"""
image[image < 0.5] = 0
image[image >= 0.5] = 1
return image
def preprocess(
self,
image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
@@ -154,6 +230,25 @@ class VaeImageProcessor(ConfigMixin):
Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
"""
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
# Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
if isinstance(image, torch.Tensor):
# if image is a pytorch tensor, it could have 2 possible shapes:
# 1. batch x height x width: we should insert the channel dimension at position 1
# 2. channel x height x width: we should insert the batch dimension at position 0,
# however, since both the channel and the batch dimension have size 1, it is the same to insert at position 1
# for simplicity, we insert a dimension of size 1 at position 1 for both cases
image = image.unsqueeze(1)
else:
# if it is a numpy array, it could have 2 possible shapes:
# 1. batch x height x width: insert channel dimension on last position
# 2. height x width x channel: insert batch dimension on first position
if image.shape[-1] == 1:
image = np.expand_dims(image, axis=0)
else:
image = np.expand_dims(image, axis=-1)
if isinstance(image, supported_formats):
image = [image]
elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
@@ -164,42 +259,47 @@ class VaeImageProcessor(ConfigMixin):
if isinstance(image[0], PIL.Image.Image):
if self.config.do_convert_rgb:
image = [self.convert_to_rgb(i) for i in image]
elif self.config.do_convert_grayscale:
image = [self.convert_to_grayscale(i) for i in image]
if self.config.do_resize:
height, width = self.get_default_height_width(image[0], height, width)
image = [self.resize(i, height, width) for i in image]
image = self.pil_to_numpy(image)  # to np
image = self.numpy_to_pt(image)  # to pt
elif isinstance(image[0], np.ndarray):
image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
image = self.numpy_to_pt(image)
- _, _, height, width = image.shape
- if self.config.do_resize and (
- height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
- ):
- raise ValueError(
- f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}"
- f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
- )
+ height, width = self.get_default_height_width(image, height, width)
+ if self.config.do_resize and (image.shape[2] != height or image.shape[3] != width):
+ raise ValueError(
+ f"Currently we only support resizing for PIL image - please resize your numpy array to be {height} and {width}"
+ f"currently the sizes are {image.shape[2]} and {image.shape[3]}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
+ )
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
- _, channel, height, width = image.shape
+ if self.config.do_convert_grayscale and image.ndim == 3:
+ image = image.unsqueeze(1)
+ channel = image.shape[1]
# don't need any preprocess if the image is latents
if channel == 4:
return image
- if self.config.do_resize and (
- height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
- ):
- raise ValueError(
- f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}"
- f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
- )
+ height, width = self.get_default_height_width(image, height, width)
+ if self.config.do_resize and (image.shape[2] != height or image.shape[3] != width):
+ raise ValueError(
+ f"Currently we only support resizing for PIL image - please resize your torch tensor to be {height} and {width}"
+ f"currently the sizes are {image.shape[2]} and {image.shape[3]}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
+ )
# expected range [0,1], normalize to [-1,1]
do_normalize = self.config.do_normalize
- if image.min() < 0:
+ if image.min() < 0 and do_normalize:
warnings.warn(
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
@@ -210,6 +310,9 @@ class VaeImageProcessor(ConfigMixin):
if do_normalize:
image = self.normalize(image)
if self.config.do_binarize:
image = self.binarize(image)
return image
def postprocess(
...
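A minimal usage sketch, not part of the diff: the mask sizes and contents below are made up, but the constructor arguments mirror the new mask_processor instances added to the inpaint pipelines further down. It shows a PIL mask being turned into a binarized, single-channel tensor.

import numpy as np
import PIL.Image
from diffusers.image_processor import VaeImageProcessor

mask_processor = VaeImageProcessor(
    vae_scale_factor=8,
    do_normalize=False,  # masks stay in [0, 1] instead of being mapped to [-1, 1]
    do_binarize=True,  # threshold at 0.5 so values become exactly 0 or 1
    do_convert_grayscale=True,  # force a single (luminance) channel
)

# a 512x512 mask with a white square in the middle (assumed example data)
mask = np.zeros((512, 512), dtype=np.uint8)
mask[128:384, 128:384] = 255
mask_pil = PIL.Image.fromarray(mask)

mask_tensor = mask_processor.preprocess(mask_pil, height=512, width=512)
print(mask_tensor.shape)  # torch.Size([1, 1, 512, 512])
print(mask_tensor.unique())  # tensor([0., 1.])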
@@ -25,7 +25,7 @@ from transformers import CLIPImageProcessor, XLMRobertaTokenizer
from diffusers.utils import is_accelerate_available, is_accelerate_version
from ...configuration_utils import FrozenDict
- from ...image_processor import VaeImageProcessor
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
@@ -567,14 +567,7 @@ class AltDiffusionImg2ImgPipeline(
def __call__(
self,
prompt: Union[str, List[str]] = None,
- image: Union[
- torch.FloatTensor,
- PIL.Image.Image,
- np.ndarray,
- List[torch.FloatTensor],
- List[PIL.Image.Image],
- List[np.ndarray],
- ] = None,
+ image: PipelineImageInput = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
@@ -597,7 +590,10 @@ class AltDiffusionImg2ImgPipeline(
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
- `Image` or tensor representing an image batch to be used as the starting point. Can also accept image
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
latents as `image`, but if passing latents directly it is not encoded again.
strength (`float`, *optional*, defaults to 0.8):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
...
@@ -23,7 +23,7 @@ import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
- from ...image_processor import VaeImageProcessor
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
@@ -678,14 +678,7 @@ class StableDiffusionControlNetPipeline(
def __call__(
self,
prompt: Union[str, List[str]] = None,
- image: Union[
- torch.FloatTensor,
- PIL.Image.Image,
- np.ndarray,
- List[torch.FloatTensor],
- List[PIL.Image.Image],
- List[np.ndarray],
- ] = None,
+ image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
...
@@ -23,7 +23,7 @@ import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
- from ...image_processor import VaeImageProcessor
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
@@ -750,22 +750,8 @@ class StableDiffusionControlNetImg2ImgPipeline(
def __call__(
self,
prompt: Union[str, List[str]] = None,
- image: Union[
- torch.FloatTensor,
- PIL.Image.Image,
- np.ndarray,
- List[torch.FloatTensor],
- List[PIL.Image.Image],
- List[np.ndarray],
- ] = None,
- control_image: Union[
- torch.FloatTensor,
- PIL.Image.Image,
- np.ndarray,
- List[torch.FloatTensor],
- List[PIL.Image.Image],
- List[np.ndarray],
- ] = None,
+ image: PipelineImageInput = None,
+ control_image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 0.8,
...
@@ -24,11 +24,12 @@ import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
- from ...image_processor import VaeImageProcessor
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
deprecate,
is_accelerate_available,
is_accelerate_version,
is_compiled_module,
@@ -133,7 +134,12 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
dimensions: ``batch x channels x height x width``.
"""
deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
deprecate(
"prepare_mask_and_masked_image",
"0.30.0",
deprecation_message,
)
if image is None:
raise ValueError("`image` input cannot be undefined.")
@@ -316,6 +322,9 @@ class StableDiffusionControlNetInpaintPipeline(
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
)
@@ -615,7 +624,7 @@ class StableDiffusionControlNetInpaintPipeline(
control_guidance_start=0.0,
control_guidance_end=1.0,
):
- if height % 8 != 0 or width % 8 != 0:
+ if height is not None and height % 8 != 0 or width is not None and width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
@@ -863,31 +872,6 @@ class StableDiffusionControlNetInpaintPipeline(
return outputs
def _default_height_width(self, height, width, image):
# NOTE: It is possible that a list of images have different
# dimensions for each image, so just checking the first image
# is not _exactly_ correct, but it is simple.
while isinstance(image, list):
image = image[0]
if height is None:
if isinstance(image, PIL.Image.Image):
height = image.height
elif isinstance(image, torch.Tensor):
height = image.shape[2]
height = (height // 8) * 8 # round down to nearest multiple of 8
if width is None:
if isinstance(image, PIL.Image.Image):
width = image.width
elif isinstance(image, torch.Tensor):
width = image.shape[3]
width = (width // 8) * 8 # round down to nearest multiple of 8
return height, width
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents
def prepare_mask_latents(
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
@@ -950,16 +934,9 @@ class StableDiffusionControlNetInpaintPipeline(
def __call__(
self,
prompt: Union[str, List[str]] = None,
- image: Union[torch.Tensor, PIL.Image.Image] = None,
- mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
- control_image: Union[
- torch.FloatTensor,
- PIL.Image.Image,
- np.ndarray,
- List[torch.FloatTensor],
- List[PIL.Image.Image],
- List[np.ndarray],
- ] = None,
+ image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
+ control_image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 1.0,
@@ -989,14 +966,29 @@ class StableDiffusionControlNetInpaintPipeline(
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
- image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`,
+ `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be inpainted (which parts of the image to
+ be masked out with `mask_image` and repainted according to `prompt`). For both numpy array and pytorch
+ tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list of tensors, the
+ expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the
+ expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image latents as `image`, but
+ if passing latents directly it is not encoded again.
+ mask_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`,
+ `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+ color channel (L) instead of 3, so the expected shape for a pytorch tensor would be `(B, 1, H, W)`,
+ `(B, H, W)`, `(1, H, W)`, or `(H, W)`, and for a numpy array `(B, H, W, 1)`, `(B, H, W)`, `(H, W, 1)`,
+ or `(H, W)`.
control_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
`List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
- The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
- the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
- also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
- height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
- specified in init, images must be passed as a list such that each element of the list can be correctly
- batched for input to a single controlnet.
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. The
+ dimensions of the output image default to `image`'s dimensions. If height and/or width are passed,
+ `image` is resized according to them. If multiple ControlNets are specified in init, images must be
+ passed as a list such that each element of the list can be correctly batched for input to a single
+ controlnet.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -1080,9 +1072,6 @@ class StableDiffusionControlNetInpaintPipeline(
"""
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
- # 0. Default height and width to unet
- height, width = self._default_height_width(height, width, image)
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
@@ -1184,9 +1173,13 @@ class StableDiffusionControlNetInpaintPipeline(
assert False
# 4. Preprocess mask and image - resizes image and mask w.r.t height and width
- mask, masked_image, init_image = prepare_mask_and_masked_image(
- image, mask_image, height, width, return_image=True
- )
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
+ init_image = init_image.to(dtype=torch.float32)
+ mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+ masked_image = init_image * (mask < 0.5)
+ _, _, height, width = init_image.shape
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
...
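The hunk above swaps prepare_mask_and_masked_image for two VaeImageProcessor instances. A standalone sketch of the equivalent preparation, under the assumption of dummy inputs and vae_scale_factor=8:

import torch
from diffusers.image_processor import VaeImageProcessor

image_processor = VaeImageProcessor(vae_scale_factor=8)
mask_processor = VaeImageProcessor(
    vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)

# dummy RGB image in [0, 1] with shape (B, C, H, W) and a single-channel (H, W) mask
image = torch.rand(1, 3, 512, 512)
mask_image = torch.rand(512, 512)

init_image = image_processor.preprocess(image, height=512, width=512).to(dtype=torch.float32)
mask = mask_processor.preprocess(mask_image, height=512, width=512)

# zero out the region to be repainted (mask >= 0.5), as the pipeline does before VAE encoding
masked_image = init_image * (mask < 0.5)
print(init_image.shape, mask.shape, masked_image.shape)
# torch.Size([1, 3, 512, 512]) torch.Size([1, 1, 512, 512]) torch.Size([1, 3, 512, 512])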
@@ -25,7 +25,7 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
from diffusers.utils.import_utils import is_invisible_watermark_available
- from ...image_processor import VaeImageProcessor
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.attention_processor import (
@@ -755,14 +755,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
- image: Union[
- torch.FloatTensor,
- PIL.Image.Image,
- np.ndarray,
- List[torch.FloatTensor],
- List[PIL.Image.Image],
- List[np.ndarray],
- ] = None,
+ image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
...
@@ -25,7 +25,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers.utils import is_accelerate_available, is_accelerate_version
from ...configuration_utils import FrozenDict
- from ...image_processor import VaeImageProcessor
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler
@@ -578,14 +578,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
self,
prompt: Union[str, List[str]],
source_prompt: Union[str, List[str]],
- image: Union[
- torch.FloatTensor,
- PIL.Image.Image,
- np.ndarray,
- List[torch.FloatTensor],
- List[PIL.Image.Image],
- List[np.ndarray],
- ] = None,
+ image: PipelineImageInput = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
...
@@ -24,7 +24,7 @@ from packaging import version
from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation
from ...configuration_utils import FrozenDict
- from ...image_processor import VaeImageProcessor
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
@@ -499,14 +499,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
def __call__(
self,
prompt: Union[str, List[str]] = None,
- image: Union[
- torch.FloatTensor,
- PIL.Image.Image,
- np.ndarray,
- List[torch.FloatTensor],
- List[PIL.Image.Image],
- List[np.ndarray],
- ] = None,
+ image: PipelineImageInput = None,
depth_map: Optional[torch.FloatTensor] = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
...
@@ -23,7 +23,7 @@ from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
- from ...image_processor import VaeImageProcessor
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
@@ -573,14 +573,7 @@ class StableDiffusionImg2ImgPipeline(
def __call__(
self,
prompt: Union[str, List[str]] = None,
- image: Union[
- torch.FloatTensor,
- PIL.Image.Image,
- np.ndarray,
- List[torch.FloatTensor],
- List[PIL.Image.Image],
- List[np.ndarray],
- ] = None,
+ image: PipelineImageInput = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
@@ -603,7 +596,10 @@ class StableDiffusionImg2ImgPipeline(
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
- `Image` or tensor representing an image batch to be used as the starting point. Can also accept image
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
latents as `image`, but if passing latents directly it is not encoded again.
strength (`float`, *optional*, defaults to 0.8):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
...
@@ -22,7 +22,7 @@ from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
- from ...image_processor import VaeImageProcessor
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
@@ -63,7 +63,12 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
dimensions: ``batch x channels x height x width``.
"""
deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
deprecate(
"prepare_mask_and_masked_image",
"0.30.0",
deprecation_message,
)
if image is None:
raise ValueError("`image` input cannot be undefined.")
@@ -280,6 +285,9 @@ class StableDiffusionInpaintPipeline(
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
@@ -679,8 +687,8 @@ class StableDiffusionInpaintPipeline(
def __call__(
self,
prompt: Union[str, List[str]] = None,
- image: Union[torch.FloatTensor, PIL.Image.Image] = None,
- mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 1.0,
@@ -705,14 +713,20 @@ class StableDiffusionInpaintPipeline(
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
- image (`PIL.Image.Image`):
- `Image` or tensor representing an image batch to be inpainted (which parts of the image to be masked
- out with `mask_image` and repainted according to `prompt`).
- mask_image (`PIL.Image.Image`):
- `Image` or tensor representing an image batch to mask `image`. White pixels in the mask are repainted
- while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a single channel
- (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the
- expected shape would be `(B, H, W, 1)`.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be inpainted (which parts of the image to
+ be masked out with `mask_image` and repainted according to `prompt`). For both numpy array and pytorch
+ tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list of tensors, the
+ expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the
+ expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image latents as `image`, but
+ if passing latents directly it is not encoded again.
+ mask_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+ color channel (L) instead of 3, so the expected shape for a pytorch tensor would be `(B, 1, H, W)`,
+ `(B, H, W)`, `(1, H, W)`, or `(H, W)`, and for a numpy array `(B, H, W, 1)`, `(B, H, W)`, `(H, W, 1)`,
+ or `(H, W)`.
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
@@ -865,9 +879,14 @@ class StableDiffusionInpaintPipeline(
is_strength_max = strength == 1.0
# 5. Preprocess mask and image
- mask, masked_image, init_image = prepare_mask_and_masked_image(
- image, mask_image, height, width, return_image=True
- )
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
+ init_image = init_image.to(dtype=torch.float32)
+ mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+ masked_image = init_image * (mask < 0.5)
mask_condition = mask.clone()
# 6. Prepare latent variables
...
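With the changes above, a plain 2D numpy mask can be passed straight to the inpaint pipeline. An end-to-end sketch; the checkpoint name and image URL are illustrative assumptions, and it needs a GPU plus a weights download:

import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from diffusers.utils import load_image

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
).to("cuda")

init_image = load_image(
    "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
).resize((512, 512))

# a plain (H, W) float mask in [0, 1]; the new mask_processor adds the batch/channel
# dimensions and binarizes it internally
mask = np.zeros((512, 512), dtype=np.float32)
mask[100:400, 100:400] = 1.0

result = pipe(
    prompt="a red bench in a park",
    image=init_image,
    mask_image=mask,
    num_inference_steps=30,
).images[0]
result.save("inpaint.png")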
@@ -21,7 +21,7 @@ import PIL
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
- from ...image_processor import VaeImageProcessor
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
@@ -147,14 +147,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
def __call__(
self,
prompt: Union[str, List[str]] = None,
- image: Union[
- torch.FloatTensor,
- PIL.Image.Image,
- np.ndarray,
- List[torch.FloatTensor],
- List[PIL.Image.Image],
- List[np.ndarray],
- ] = None,
+ image: PipelineImageInput = None,
num_inference_steps: int = 100,
guidance_scale: float = 7.5,
image_guidance_scale: float = 1.5,
...
@@ -21,7 +21,7 @@ import torch
import torch.nn.functional as F
from transformers import CLIPTextModel, CLIPTokenizer
- from ...image_processor import VaeImageProcessor
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import EulerDiscreteScheduler
from ...utils import logging, randn_tensor
@@ -257,14 +257,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
def __call__(
self,
prompt: Union[str, List[str]],
- image: Union[
- torch.FloatTensor,
- PIL.Image.Image,
- np.ndarray,
- List[torch.FloatTensor],
- List[PIL.Image.Image],
- List[np.ndarray],
- ] = None,
+ image: PipelineImageInput = None,
num_inference_steps: int = 75,
guidance_scale: float = 9.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
...
@@ -29,7 +29,7 @@ from transformers import (
CLIPTokenizer,
)
- from ...image_processor import VaeImageProcessor
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import Attention
@@ -1066,14 +1066,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
def invert(
self,
prompt: Optional[str] = None,
- image: Union[
- torch.FloatTensor,
- PIL.Image.Image,
- np.ndarray,
- List[torch.FloatTensor],
- List[PIL.Image.Image],
- List[np.ndarray],
- ] = None,
+ image: PipelineImageInput = None,
num_inference_steps: int = 50,
guidance_scale: float = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
...
@@ -21,7 +21,7 @@ import PIL
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
- from ...image_processor import VaeImageProcessor
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
@@ -496,14 +496,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
def __call__(
self,
prompt: Union[str, List[str]] = None,
- image: Union[
- torch.FloatTensor,
- PIL.Image.Image,
- np.ndarray,
- List[torch.FloatTensor],
- List[PIL.Image.Image],
- List[np.ndarray],
- ] = None,
+ image: PipelineImageInput = None,
num_inference_steps: int = 75,
guidance_scale: float = 9.0,
noise_level: int = 20,
...
@@ -16,12 +16,11 @@ import inspect
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
- import numpy as np
import PIL.Image
import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
- from ...image_processor import VaeImageProcessor
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
@@ -656,14 +655,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
- image: Union[
- torch.FloatTensor,
- PIL.Image.Image,
- np.ndarray,
- List[torch.FloatTensor],
- List[PIL.Image.Image],
- List[np.ndarray],
- ] = None,
+ image: PipelineImageInput = None,
strength: float = 0.3,
num_inference_steps: int = 50,
denoising_start: Optional[float] = None,
...
@@ -21,7 +21,7 @@ import PIL
import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
- from ...image_processor import VaeImageProcessor
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
@@ -32,6 +32,7 @@ from ...models.attention_processor import (
)
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
deprecate,
is_accelerate_available,
is_accelerate_version,
is_invisible_watermark_available,
@@ -140,6 +141,12 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
"""
# checkpoint. TODO(Yiyi) - need to clean this up later
deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
deprecate(
"prepare_mask_and_masked_image",
"0.30.0",
deprecation_message,
)
if image is None:
raise ValueError("`image` input cannot be undefined.")
@@ -290,6 +297,9 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -857,8 +867,8 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
- image: Union[torch.FloatTensor, PIL.Image.Image] = None,
- mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 1.0,
@@ -1098,9 +1108,16 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS
is_strength_max = strength == 1.0
# 5. Preprocess mask and image
- mask, masked_image, init_image = prepare_mask_and_masked_image(
- image, mask_image, height, width, return_image=True
- )
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
+ init_image = init_image.to(dtype=torch.float32)
+ mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+ if init_image.shape[1] == 4:
+ # if images are in latent space, we can't mask it
+ masked_image = None
+ else:
+ masked_image = init_image * (mask < 0.5)
# 6. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
...
@@ -15,12 +15,11 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
- import numpy as np
import PIL.Image
import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
- from ...image_processor import VaeImageProcessor
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
@@ -587,14 +586,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
def __call__(
self,
prompt: Union[str, List[str]] = None,
- image: Union[
- torch.FloatTensor,
- PIL.Image.Image,
- np.ndarray,
- List[torch.FloatTensor],
- List[PIL.Image.Image],
- List[np.ndarray],
- ] = None,
+ image: PipelineImageInput = None,
num_inference_steps: int = 100,
guidance_scale: float = 7.5,
image_guidance_scale: float = 1.5,
...
@@ -34,6 +34,17 @@ class ImageProcessorTest(unittest.TestCase):
return sample
@property
def dummy_mask(self):
batch_size = 1
num_channels = 1
height = 8
width = 8
sample = torch.rand((batch_size, num_channels, height, width))
return sample
def to_np(self, image):
if isinstance(image[0], PIL.Image.Image):
return np.stack([np.array(i) for i in image], axis=0)
@@ -133,17 +144,144 @@ class ImageProcessorTest(unittest.TestCase):
)
input_np_4d = self.to_np(self.dummy_sample)
- list(input_np_4d)
+ input_np_list = list(input_np_4d)
out_np_4d = image_processor.postprocess(
- image_processor.preprocess(input_pt_4d),
+ image_processor.preprocess(input_np_4d),
output_type="np",
)
out_np_list = image_processor.postprocess(
- image_processor.preprocess(input_pt_list),
+ image_processor.preprocess(input_np_list),
output_type="np",
)
assert np.abs(out_pt_4d - out_pt_list).max() < 1e-6
assert np.abs(out_np_4d - out_np_list).max() < 1e-6
def test_preprocess_input_mask_3d(self):
image_processor = VaeImageProcessor(
do_resize=False, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
input_pt_4d = self.dummy_mask
input_pt_3d = input_pt_4d.squeeze(0)
input_pt_2d = input_pt_3d.squeeze(0)
out_pt_4d = image_processor.postprocess(
image_processor.preprocess(input_pt_4d),
output_type="np",
)
out_pt_3d = image_processor.postprocess(
image_processor.preprocess(input_pt_3d),
output_type="np",
)
out_pt_2d = image_processor.postprocess(
image_processor.preprocess(input_pt_2d),
output_type="np",
)
input_np_4d = self.to_np(self.dummy_mask)
input_np_3d = input_np_4d.squeeze(0)
input_np_3d_1 = input_np_4d.squeeze(-1)
input_np_2d = input_np_3d.squeeze(-1)
out_np_4d = image_processor.postprocess(
image_processor.preprocess(input_np_4d),
output_type="np",
)
out_np_3d = image_processor.postprocess(
image_processor.preprocess(input_np_3d),
output_type="np",
)
out_np_3d_1 = image_processor.postprocess(
image_processor.preprocess(input_np_3d_1),
output_type="np",
)
out_np_2d = image_processor.postprocess(
image_processor.preprocess(input_np_2d),
output_type="np",
)
assert np.abs(out_pt_4d - out_pt_3d).max() == 0
assert np.abs(out_pt_4d - out_pt_2d).max() == 0
assert np.abs(out_np_4d - out_np_3d).max() == 0
assert np.abs(out_np_4d - out_np_3d_1).max() == 0
assert np.abs(out_np_4d - out_np_2d).max() == 0
def test_preprocess_input_mask_list(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=False, do_convert_grayscale=True)
input_pt_4d = self.dummy_mask
input_pt_3d = input_pt_4d.squeeze(0)
input_pt_2d = input_pt_3d.squeeze(0)
inputs_pt = [input_pt_4d, input_pt_3d, input_pt_2d]
inputs_pt_list = [[input_pt] for input_pt in inputs_pt]
for input_pt, input_pt_list in zip(inputs_pt, inputs_pt_list):
out_pt = image_processor.postprocess(
image_processor.preprocess(input_pt),
output_type="np",
)
out_pt_list = image_processor.postprocess(
image_processor.preprocess(input_pt_list),
output_type="np",
)
assert np.abs(out_pt - out_pt_list).max() < 1e-6
input_np_4d = self.to_np(self.dummy_mask)
input_np_3d = input_np_4d.squeeze(0)
input_np_2d = input_np_3d.squeeze(-1)
inputs_np = [input_np_4d, input_np_3d, input_np_2d]
inputs_np_list = [[input_np] for input_np in inputs_np]
for input_np, input_np_list in zip(inputs_np, inputs_np_list):
out_np = image_processor.postprocess(
image_processor.preprocess(input_np),
output_type="np",
)
out_np_list = image_processor.postprocess(
image_processor.preprocess(input_np_list),
output_type="np",
)
assert np.abs(out_np - out_np_list).max() < 1e-6
def test_preprocess_input_mask_3d_batch(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=False, do_convert_grayscale=True)
# create a dummy mask input with batch_size 2
dummy_mask_batch = torch.cat([self.dummy_mask] * 2, axis=0)
# squeeze out the channel dimension
input_pt_3d = dummy_mask_batch.squeeze(1)
input_np_3d = self.to_np(dummy_mask_batch).squeeze(-1)
input_pt_3d_list = list(input_pt_3d)
input_np_3d_list = list(input_np_3d)
out_pt_3d = image_processor.postprocess(
image_processor.preprocess(input_pt_3d),
output_type="np",
)
out_pt_3d_list = image_processor.postprocess(
image_processor.preprocess(input_pt_3d_list),
output_type="np",
)
assert np.abs(out_pt_3d - out_pt_3d_list).max() < 1e-6
out_np_3d = image_processor.postprocess(
image_processor.preprocess(input_np_3d),
output_type="np",
)
out_np_3d_list = image_processor.postprocess(
image_processor.preprocess(input_np_3d_list),
output_type="np",
)
assert np.abs(out_np_3d - out_np_3d_list).max() < 1e-6
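A quick equivalence check mirroring the new tests, as a sketch outside the test class: with do_convert_grayscale enabled, 4D, 3D and 2D mask inputs should preprocess to the same tensor.

import torch
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(do_resize=False, do_normalize=False, do_convert_grayscale=True)

mask_4d = torch.rand(1, 1, 8, 8)  # batch x channel x height x width
mask_3d = mask_4d.squeeze(0)  # channel x height x width
mask_2d = mask_3d.squeeze(0)  # height x width

out_4d = processor.preprocess(mask_4d)
out_3d = processor.preprocess(mask_3d)
out_2d = processor.preprocess(mask_2d)

assert out_4d.shape == (1, 1, 8, 8)
assert torch.equal(out_4d, out_3d) and torch.equal(out_4d, out_2d)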