Unverified Commit b7b1a30b authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

refactor prepare_mask_and_masked_image with VaeImageProcessor (#4444)



* refactor image processor for mask
---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
parent 7e5587a5
......@@ -24,6 +24,16 @@ from .configuration_utils import ConfigMixin, register_to_config
from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
PipelineImageInput = Union[
PIL.Image.Image,
np.ndarray,
torch.FloatTensor,
List[PIL.Image.Image],
List[np.ndarray],
List[torch.FloatTensor],
]
class VaeImageProcessor(ConfigMixin):
"""
Image processor for VAE.
......@@ -38,8 +48,12 @@ class VaeImageProcessor(ConfigMixin):
Resampling filter to use when resizing the image.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image to [-1,1].
do_binarize (`bool`, *optional*, defaults to `True`):
Whether to binarize the image to 0/1.
do_convert_rgb (`bool`, *optional*, defaults to be `False`):
Whether to convert the images to RGB format.
do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
Whether to convert the images to grayscale format.
"""
config_name = CONFIG_NAME
......@@ -51,9 +65,18 @@ class VaeImageProcessor(ConfigMixin):
vae_scale_factor: int = 8,
resample: str = "lanczos",
do_normalize: bool = True,
do_binarize: bool = False,
do_convert_rgb: bool = False,
do_convert_grayscale: bool = False,
):
super().__init__()
if do_convert_rgb and do_convert_grayscale:
raise ValueError(
"`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
" if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
" if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
)
self.config.do_convert_rgb = False
@staticmethod
def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
......@@ -119,31 +142,84 @@ class VaeImageProcessor(ConfigMixin):
@staticmethod
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
"""
Converts an image to RGB format.
Converts a PIL image to RGB format.
"""
image = image.convert("RGB")
return image
def resize(
@staticmethod
def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
"""
Converts a PIL image to grayscale format.
"""
image = image.convert("L")
return image
def get_default_height_width(
self,
image: PIL.Image.Image,
image: [PIL.Image.Image, np.ndarray, torch.Tensor],
height: Optional[int] = None,
width: Optional[int] = None,
) -> PIL.Image.Image:
):
"""
Resize a PIL image. Both height and width are downscaled to the next integer multiple of `vae_scale_factor`.
This function return the height and width that are downscaled to the next integer multiple of
`vae_scale_factor`.
Args:
image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should
have shape `[batch, channel, height, width]`.
height (`int`, *optional*, defaults to `None`):
The height in preprocessed image. If `None`, will use the height of `image` input.
width (`int`, *optional*`, defaults to `None`):
The width in preprocessed. If `None`, will use the width of the `image` input.
"""
if height is None:
height = image.height
if isinstance(image, PIL.Image.Image):
height = image.height
elif isinstance(image, torch.Tensor):
height = image.shape[2]
else:
height = image.shape[1]
if width is None:
width = image.width
if isinstance(image, PIL.Image.Image):
width = image.width
elif isinstance(image, torch.Tensor):
width = image.shape[3]
else:
height = image.shape[2]
width, height = (
x - x % self.config.vae_scale_factor for x in (width, height)
) # resize to integer multiple of vae_scale_factor
return height, width
def resize(
self,
image: PIL.Image.Image,
height: Optional[int] = None,
width: Optional[int] = None,
) -> PIL.Image.Image:
"""
Resize a PIL image.
"""
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
return image
def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
"""
create a mask
"""
image[image < 0.5] = 0
image[image >= 0.5] = 1
return image
def preprocess(
self,
image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
......@@ -154,6 +230,25 @@ class VaeImageProcessor(ConfigMixin):
Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
"""
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
# Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
if isinstance(image, torch.Tensor):
# if image is a pytorch tensor could have 2 possible shapes:
# 1. batch x height x width: we should insert the channel dimension at position 1
# 2. channnel x height x width: we should insert batch dimension at position 0,
# however, since both channel and batch dimension has same size 1, it is same to insert at position 1
# for simplicity, we insert a dimension of size 1 at position 1 for both cases
image = image.unsqueeze(1)
else:
# if it is a numpy array, it could have 2 possible shapes:
# 1. batch x height x width: insert channel dimension on last position
# 2. height x width x channel: insert batch dimension on first position
if image.shape[-1] == 1:
image = np.expand_dims(image, axis=0)
else:
image = np.expand_dims(image, axis=-1)
if isinstance(image, supported_formats):
image = [image]
elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
......@@ -164,42 +259,47 @@ class VaeImageProcessor(ConfigMixin):
if isinstance(image[0], PIL.Image.Image):
if self.config.do_convert_rgb:
image = [self.convert_to_rgb(i) for i in image]
elif self.config.do_convert_grayscale:
image = [self.convert_to_grayscale(i) for i in image]
if self.config.do_resize:
height, width = self.get_default_height_width(image[0], height, width)
image = [self.resize(i, height, width) for i in image]
image = self.pil_to_numpy(image) # to np
image = self.numpy_to_pt(image) # to pt
elif isinstance(image[0], np.ndarray):
image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
image = self.numpy_to_pt(image)
_, _, height, width = image.shape
if self.config.do_resize and (
height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
):
height, width = self.get_default_height_width(image, height, width)
if self.config.do_resize and (image.shape[2] != height or image.shape[3] != width):
raise ValueError(
f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}"
f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
f"Currently we only support resizing for PIL image - please resize your numpy array to be {height} and {width}"
f"currently the sizes are {image.shape[2]} and {image.shape[3]}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
_, channel, height, width = image.shape
if self.config.do_convert_grayscale and image.ndim == 3:
image = image.unsqueeze(1)
channel = image.shape[1]
# don't need any preprocess if the image is latents
if channel == 4:
return image
if self.config.do_resize and (
height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
):
height, width = self.get_default_height_width(image, height, width)
if self.config.do_resize and (image.shape[2] != height or image.shape[3] != width):
raise ValueError(
f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}"
f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
f"Currently we only support resizing for PIL image - please resize your torch tensor to be {height} and {width}"
f"currently the sizes are {image.shape[2]} and {image.shape[3]}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
)
# expected range [0,1], normalize to [-1,1]
do_normalize = self.config.do_normalize
if image.min() < 0:
if image.min() < 0 and do_normalize:
warnings.warn(
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
......@@ -210,6 +310,9 @@ class VaeImageProcessor(ConfigMixin):
if do_normalize:
image = self.normalize(image)
if self.config.do_binarize:
image = self.binarize(image)
return image
def postprocess(
......
......@@ -25,7 +25,7 @@ from transformers import CLIPImageProcessor, XLMRobertaTokenizer
from diffusers.utils import is_accelerate_available, is_accelerate_version
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
......@@ -567,14 +567,7 @@ class AltDiffusionImg2ImgPipeline(
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
image: PipelineImageInput = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
......@@ -597,7 +590,10 @@ class AltDiffusionImg2ImgPipeline(
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image` or tensor representing an image batch to be used as the starting point. Can also accept image
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
latents as `image`, but if passing latents directly it is not encoded again.
strength (`float`, *optional*, defaults to 0.8):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
......
......@@ -23,7 +23,7 @@ import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
......@@ -678,14 +678,7 @@ class StableDiffusionControlNetPipeline(
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
......
......@@ -23,7 +23,7 @@ import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
......@@ -750,22 +750,8 @@ class StableDiffusionControlNetImg2ImgPipeline(
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
control_image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
image: PipelineImageInput = None,
control_image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 0.8,
......
......@@ -24,11 +24,12 @@ import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
deprecate,
is_accelerate_available,
is_accelerate_version,
is_compiled_module,
......@@ -133,7 +134,12 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
dimensions: ``batch x channels x height x width``.
"""
deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
deprecate(
"prepare_mask_and_masked_image",
"0.30.0",
deprecation_message,
)
if image is None:
raise ValueError("`image` input cannot be undefined.")
......@@ -316,6 +322,9 @@ class StableDiffusionControlNetInpaintPipeline(
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
)
......@@ -615,7 +624,7 @@ class StableDiffusionControlNetInpaintPipeline(
control_guidance_start=0.0,
control_guidance_end=1.0,
):
if height % 8 != 0 or width % 8 != 0:
if height is not None and height % 8 != 0 or width is not None and width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
......@@ -863,31 +872,6 @@ class StableDiffusionControlNetInpaintPipeline(
return outputs
def _default_height_width(self, height, width, image):
# NOTE: It is possible that a list of images have different
# dimensions for each image, so just checking the first image
# is not _exactly_ correct, but it is simple.
while isinstance(image, list):
image = image[0]
if height is None:
if isinstance(image, PIL.Image.Image):
height = image.height
elif isinstance(image, torch.Tensor):
height = image.shape[2]
height = (height // 8) * 8 # round down to nearest multiple of 8
if width is None:
if isinstance(image, PIL.Image.Image):
width = image.width
elif isinstance(image, torch.Tensor):
width = image.shape[3]
width = (width // 8) * 8 # round down to nearest multiple of 8
return height, width
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents
def prepare_mask_latents(
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
......@@ -950,16 +934,9 @@ class StableDiffusionControlNetInpaintPipeline(
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[torch.Tensor, PIL.Image.Image] = None,
mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
control_image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
image: PipelineImageInput = None,
mask_image: PipelineImageInput = None,
control_image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 1.0,
......@@ -989,14 +966,29 @@ class StableDiffusionControlNetInpaintPipeline(
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`,
`List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to be inpainted (which parts of the image to
be masked out with `mask_image` and repainted according to `prompt`). For both numpy array and pytorch
tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the
expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the
expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image latents as `image`, but
if passing latents directly it is not encoded again.
mask_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`,
`List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
1)`, or `(H, W)`.
control_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
`List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
specified in init, images must be passed as a list such that each element of the list can be correctly
batched for input to a single controlnet.
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. The
dimensions of the output image defaults to `image`'s dimensions. If height and/or width are passed,
`image` is resized according to them. If multiple ControlNets are specified in init, images must be
passed as a list such that each element of the list can be correctly batched for input to a single
controlnet.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
......@@ -1080,9 +1072,6 @@ class StableDiffusionControlNetInpaintPipeline(
"""
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# 0. Default height and width to unet
height, width = self._default_height_width(height, width, image)
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
......@@ -1184,9 +1173,13 @@ class StableDiffusionControlNetInpaintPipeline(
assert False
# 4. Preprocess mask and image - resizes image and mask w.r.t height and width
mask, masked_image, init_image = prepare_mask_and_masked_image(
image, mask_image, height, width, return_image=True
)
init_image = self.image_processor.preprocess(image, height=height, width=width)
init_image = init_image.to(dtype=torch.float32)
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
masked_image = init_image * (mask < 0.5)
_, _, height, width = init_image.shape
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
......
......@@ -25,7 +25,7 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
from diffusers.utils.import_utils import is_invisible_watermark_available
from ...image_processor import VaeImageProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.attention_processor import (
......@@ -755,14 +755,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
......
......@@ -25,7 +25,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers.utils import is_accelerate_available, is_accelerate_version
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler
......@@ -578,14 +578,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
self,
prompt: Union[str, List[str]],
source_prompt: Union[str, List[str]],
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
image: PipelineImageInput = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
......
......@@ -24,7 +24,7 @@ from packaging import version
from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
......@@ -499,14 +499,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
image: PipelineImageInput = None,
depth_map: Optional[torch.FloatTensor] = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
......
......@@ -23,7 +23,7 @@ from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
......@@ -573,14 +573,7 @@ class StableDiffusionImg2ImgPipeline(
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
image: PipelineImageInput = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
......@@ -603,7 +596,10 @@ class StableDiffusionImg2ImgPipeline(
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image` or tensor representing an image batch to be used as the starting point. Can also accept image
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
latents as `image`, but if passing latents directly it is not encoded again.
strength (`float`, *optional*, defaults to 0.8):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
......
......@@ -22,7 +22,7 @@ from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
......@@ -63,7 +63,12 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
dimensions: ``batch x channels x height x width``.
"""
deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
deprecate(
"prepare_mask_and_masked_image",
"0.30.0",
deprecation_message,
)
if image is None:
raise ValueError("`image` input cannot be undefined.")
......@@ -280,6 +285,9 @@ class StableDiffusionInpaintPipeline(
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
......@@ -679,8 +687,8 @@ class StableDiffusionInpaintPipeline(
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
image: PipelineImageInput = None,
mask_image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 1.0,
......@@ -705,14 +713,20 @@ class StableDiffusionInpaintPipeline(
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
image (`PIL.Image.Image`):
`Image` or tensor representing an image batch to be inpainted (which parts of the image to be masked
out with `mask_image` and repainted according to `prompt`).
mask_image (`PIL.Image.Image`):
`Image` or tensor representing an image batch to mask `image`. White pixels in the mask are repainted
while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a single channel
(luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the
expected shape would be `(B, H, W, 1)`.
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to be inpainted (which parts of the image to
be masked out with `mask_image` and repainted according to `prompt`). For both numpy array and pytorch
tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the
expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the
expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image latents as `image`, but
if passing latents directly it is not encoded again.
mask_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
1)`, or `(H, W)`.
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
......@@ -865,9 +879,14 @@ class StableDiffusionInpaintPipeline(
is_strength_max = strength == 1.0
# 5. Preprocess mask and image
mask, masked_image, init_image = prepare_mask_and_masked_image(
image, mask_image, height, width, return_image=True
)
init_image = self.image_processor.preprocess(image, height=height, width=width)
init_image = init_image.to(dtype=torch.float32)
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
masked_image = init_image * (mask < 0.5)
mask_condition = mask.clone()
# 6. Prepare latent variables
......
......@@ -21,7 +21,7 @@ import PIL
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
......@@ -147,14 +147,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
image: PipelineImageInput = None,
num_inference_steps: int = 100,
guidance_scale: float = 7.5,
image_guidance_scale: float = 1.5,
......
......@@ -21,7 +21,7 @@ import torch
import torch.nn.functional as F
from transformers import CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import EulerDiscreteScheduler
from ...utils import logging, randn_tensor
......@@ -257,14 +257,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
image: PipelineImageInput = None,
num_inference_steps: int = 75,
guidance_scale: float = 9.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
......
......@@ -29,7 +29,7 @@ from transformers import (
CLIPTokenizer,
)
from ...image_processor import VaeImageProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import Attention
......@@ -1066,14 +1066,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
def invert(
self,
prompt: Optional[str] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
image: PipelineImageInput = None,
num_inference_steps: int = 50,
guidance_scale: float = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
......
......@@ -21,7 +21,7 @@ import PIL
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
......@@ -496,14 +496,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
image: PipelineImageInput = None,
num_inference_steps: int = 75,
guidance_scale: float = 9.0,
noise_level: int = 20,
......
......@@ -16,12 +16,11 @@ import inspect
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import PIL.Image
import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
......@@ -656,14 +655,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
image: PipelineImageInput = None,
strength: float = 0.3,
num_inference_steps: int = 50,
denoising_start: Optional[float] = None,
......
......@@ -21,7 +21,7 @@ import PIL
import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
......@@ -32,6 +32,7 @@ from ...models.attention_processor import (
)
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
deprecate,
is_accelerate_available,
is_accelerate_version,
is_invisible_watermark_available,
......@@ -140,6 +141,12 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
"""
# checkpoint. TOD(Yiyi) - need to clean this up later
deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
deprecate(
"prepare_mask_and_masked_image",
"0.30.0",
deprecation_message,
)
if image is None:
raise ValueError("`image` input cannot be undefined.")
......@@ -290,6 +297,9 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
......@@ -857,8 +867,8 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
image: PipelineImageInput = None,
mask_image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 1.0,
......@@ -1098,9 +1108,16 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS
is_strength_max = strength == 1.0
# 5. Preprocess mask and image
mask, masked_image, init_image = prepare_mask_and_masked_image(
image, mask_image, height, width, return_image=True
)
init_image = self.image_processor.preprocess(image, height=height, width=width)
init_image = init_image.to(dtype=torch.float32)
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
if init_image.shape[1] == 4:
# if images are in latent space, we can't mask it
masked_image = None
else:
masked_image = init_image * (mask < 0.5)
# 6. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
......
......@@ -15,12 +15,11 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import PIL.Image
import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
......@@ -587,14 +586,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
image: PipelineImageInput = None,
num_inference_steps: int = 100,
guidance_scale: float = 7.5,
image_guidance_scale: float = 1.5,
......
......@@ -34,6 +34,17 @@ class ImageProcessorTest(unittest.TestCase):
return sample
@property
def dummy_mask(self):
batch_size = 1
num_channels = 1
height = 8
width = 8
sample = torch.rand((batch_size, num_channels, height, width))
return sample
def to_np(self, image):
if isinstance(image[0], PIL.Image.Image):
return np.stack([np.array(i) for i in image], axis=0)
......@@ -133,17 +144,144 @@ class ImageProcessorTest(unittest.TestCase):
)
input_np_4d = self.to_np(self.dummy_sample)
list(input_np_4d)
input_np_list = list(input_np_4d)
out_np_4d = image_processor.postprocess(
image_processor.preprocess(input_pt_4d),
image_processor.preprocess(input_np_4d),
output_type="np",
)
out_np_list = image_processor.postprocess(
image_processor.preprocess(input_pt_list),
image_processor.preprocess(input_np_list),
output_type="np",
)
assert np.abs(out_pt_4d - out_pt_list).max() < 1e-6
assert np.abs(out_np_4d - out_np_list).max() < 1e-6
def test_preprocess_input_mask_3d(self):
image_processor = VaeImageProcessor(
do_resize=False, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
input_pt_4d = self.dummy_mask
input_pt_3d = input_pt_4d.squeeze(0)
input_pt_2d = input_pt_3d.squeeze(0)
out_pt_4d = image_processor.postprocess(
image_processor.preprocess(input_pt_4d),
output_type="np",
)
out_pt_3d = image_processor.postprocess(
image_processor.preprocess(input_pt_3d),
output_type="np",
)
out_pt_2d = image_processor.postprocess(
image_processor.preprocess(input_pt_2d),
output_type="np",
)
input_np_4d = self.to_np(self.dummy_mask)
input_np_3d = input_np_4d.squeeze(0)
input_np_3d_1 = input_np_4d.squeeze(-1)
input_np_2d = input_np_3d.squeeze(-1)
out_np_4d = image_processor.postprocess(
image_processor.preprocess(input_np_4d),
output_type="np",
)
out_np_3d = image_processor.postprocess(
image_processor.preprocess(input_np_3d),
output_type="np",
)
out_np_3d_1 = image_processor.postprocess(
image_processor.preprocess(input_np_3d_1),
output_type="np",
)
out_np_2d = image_processor.postprocess(
image_processor.preprocess(input_np_2d),
output_type="np",
)
assert np.abs(out_pt_4d - out_pt_3d).max() == 0
assert np.abs(out_pt_4d - out_pt_2d).max() == 0
assert np.abs(out_np_4d - out_np_3d).max() == 0
assert np.abs(out_np_4d - out_np_3d_1).max() == 0
assert np.abs(out_np_4d - out_np_2d).max() == 0
def test_preprocess_input_mask_list(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=False, do_convert_grayscale=True)
input_pt_4d = self.dummy_mask
input_pt_3d = input_pt_4d.squeeze(0)
input_pt_2d = input_pt_3d.squeeze(0)
inputs_pt = [input_pt_4d, input_pt_3d, input_pt_2d]
inputs_pt_list = [[input_pt] for input_pt in inputs_pt]
for input_pt, input_pt_list in zip(inputs_pt, inputs_pt_list):
out_pt = image_processor.postprocess(
image_processor.preprocess(input_pt),
output_type="np",
)
out_pt_list = image_processor.postprocess(
image_processor.preprocess(input_pt_list),
output_type="np",
)
assert np.abs(out_pt - out_pt_list).max() < 1e-6
input_np_4d = self.to_np(self.dummy_mask)
input_np_3d = input_np_4d.squeeze(0)
input_np_2d = input_np_3d.squeeze(-1)
inputs_np = [input_np_4d, input_np_3d, input_np_2d]
inputs_np_list = [[input_np] for input_np in inputs_np]
for input_np, input_np_list in zip(inputs_np, inputs_np_list):
out_np = image_processor.postprocess(
image_processor.preprocess(input_np),
output_type="np",
)
out_np_list = image_processor.postprocess(
image_processor.preprocess(input_np_list),
output_type="np",
)
assert np.abs(out_np - out_np_list).max() < 1e-6
def test_preprocess_input_mask_3d_batch(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=False, do_convert_grayscale=True)
# create a dummy mask input with batch_size 2
dummy_mask_batch = torch.cat([self.dummy_mask] * 2, axis=0)
# squeeze out the channel dimension
input_pt_3d = dummy_mask_batch.squeeze(1)
input_np_3d = self.to_np(dummy_mask_batch).squeeze(-1)
input_pt_3d_list = list(input_pt_3d)
input_np_3d_list = list(input_np_3d)
out_pt_3d = image_processor.postprocess(
image_processor.preprocess(input_pt_3d),
output_type="np",
)
out_pt_3d_list = image_processor.postprocess(
image_processor.preprocess(input_pt_3d_list),
output_type="np",
)
assert np.abs(out_pt_3d - out_pt_3d_list).max() < 1e-6
out_np_3d = image_processor.postprocess(
image_processor.preprocess(input_np_3d),
output_type="np",
)
out_np_3d_list = image_processor.postprocess(
image_processor.preprocess(input_np_3d_list),
output_type="np",
)
assert np.abs(out_np_3d - out_np_3d_list).max() < 1e-6
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment