Unverified Commit 92d2baf6 authored by Charchit Sharma, committed by GitHub

refactor image_processor.py file (#9608)



* refactor image_processor file

* changes as requested

* +1 edits

* quality fix

* indent issue

---------
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent dccf39f0
@@ -38,16 +38,44 @@ PipelineImageInput = Union[
 PipelineDepthInput = PipelineImageInput
 
-def is_valid_image(image):
+def is_valid_image(image) -> bool:
+    r"""
+    Checks if the input is a valid image.
+
+    A valid image can be:
+    - A `PIL.Image.Image`.
+    - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image).
+
+    Args:
+        image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
+            The image to validate. It can be a PIL image, a NumPy array, or a torch tensor.
+
+    Returns:
+        `bool`:
+            `True` if the input is a valid image, `False` otherwise.
+    """
     return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
 
 
 def is_valid_image_imagelist(images):
-    # check if the image input is one of the supported formats for image and image list:
-    # it can be either one of below 3
-    # (1) a 4d pytorch tensor or numpy array,
-    # (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor
-    # (3) a list of valid image
+    r"""
+    Checks if the input is a valid image or list of images.
+
+    The input can be one of the following formats:
+    - A 4D tensor or numpy array (batch of images).
+    - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or
+      `torch.Tensor`.
+    - A list of valid images.
+
+    Args:
+        images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
+            The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid
+            images.
+
+    Returns:
+        `bool`:
+            `True` if the input is valid, `False` otherwise.
+    """
     if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
         return True
     elif is_valid_image(images):
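
As a quick sanity check, a sketch of what the two validators accept (assuming they are imported from `diffusers.image_processor`):

import numpy as np
import PIL.Image
import torch

from diffusers.image_processor import is_valid_image, is_valid_image_imagelist

print(is_valid_image(PIL.Image.new("RGB", (8, 8))))  # True: PIL image
print(is_valid_image(np.zeros((8, 8))))              # True: 2D array (grayscale)
print(is_valid_image(torch.zeros(3, 8, 8)))          # True: 3D tensor (channels + spatial)
print(is_valid_image(np.zeros((2, 3, 8, 8))))        # False: a 4D batch is not a single image

print(is_valid_image_imagelist(np.zeros((2, 8, 8, 3))))          # True: 4D batch
print(is_valid_image_imagelist([PIL.Image.new("RGB", (8, 8))]))  # True: list of valid images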
@@ -103,8 +131,16 @@ class VaeImageProcessor(ConfigMixin):
     @staticmethod
     def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
-        """
+        r"""
         Convert a numpy image or a batch of images to a PIL image.
+
+        Args:
+            images (`np.ndarray`):
+                The image array to convert to PIL format.
+
+        Returns:
+            `List[PIL.Image.Image]`:
+                A list of PIL images.
         """
         if images.ndim == 3:
             images = images[None, ...]
@@ -119,8 +155,16 @@ class VaeImageProcessor(ConfigMixin):
     @staticmethod
     def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
-        """
+        r"""
         Convert a PIL image or a list of PIL images to NumPy arrays.
+
+        Args:
+            images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
+                The PIL image or list of images to convert to NumPy format.
+
+        Returns:
+            `np.ndarray`:
+                A NumPy array representation of the images.
         """
         if not isinstance(images, list):
             images = [images]
@@ -131,8 +175,16 @@ class VaeImageProcessor(ConfigMixin):
     @staticmethod
     def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
-        """
+        r"""
         Convert a NumPy image to a PyTorch tensor.
+
+        Args:
+            images (`np.ndarray`):
+                The NumPy image array to convert to PyTorch format.
+
+        Returns:
+            `torch.Tensor`:
+                A PyTorch tensor representation of the images.
         """
         if images.ndim == 3:
             images = images[..., None]
@@ -142,30 +194,62 @@ class VaeImageProcessor(ConfigMixin):
     @staticmethod
     def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
-        """
+        r"""
         Convert a PyTorch tensor to a NumPy image.
+
+        Args:
+            images (`torch.Tensor`):
+                The PyTorch tensor to convert to NumPy format.
+
+        Returns:
+            `np.ndarray`:
+                A NumPy array representation of the images.
         """
         images = images.cpu().permute(0, 2, 3, 1).float().numpy()
         return images
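
Chained together, the four converters above make the shape conventions concrete; a small sketch (shapes noted in comments):

import PIL.Image

from diffusers.image_processor import VaeImageProcessor

pil = PIL.Image.new("RGB", (64, 32))
arr = VaeImageProcessor.pil_to_numpy(pil)    # (1, 32, 64, 3) float32 array in [0, 1]
pt = VaeImageProcessor.numpy_to_pt(arr)      # (1, 3, 32, 64) torch.Tensor, channels-first
back = VaeImageProcessor.pt_to_numpy(pt)     # (1, 32, 64, 3) array, channels-last again
pils = VaeImageProcessor.numpy_to_pil(back)  # list with one 64x32 PIL image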
     @staticmethod
     def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
-        """
+        r"""
         Normalize an image array to [-1,1].
+
+        Args:
+            images (`np.ndarray` or `torch.Tensor`):
+                The image array to normalize.
+
+        Returns:
+            `np.ndarray` or `torch.Tensor`:
+                The normalized image array.
         """
         return 2.0 * images - 1.0
 
     @staticmethod
     def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
-        """
+        r"""
         Denormalize an image array to [0,1].
+
+        Args:
+            images (`np.ndarray` or `torch.Tensor`):
+                The image array to denormalize.
+
+        Returns:
+            `np.ndarray` or `torch.Tensor`:
+                The denormalized image array.
         """
         return (images / 2 + 0.5).clamp(0, 1)
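
The pair is a plain affine map between the `[0, 1]` pixel range and the `[-1, 1]` range diffusion models expect; a worked example:

import torch

from diffusers.image_processor import VaeImageProcessor

x = torch.tensor([0.0, 0.5, 1.0])
y = VaeImageProcessor.normalize(x)    # tensor([-1., 0., 1.]), i.e. 2x - 1
z = VaeImageProcessor.denormalize(y)  # tensor([0., 0.5, 1.]), i.e. y / 2 + 0.5, clamped to [0, 1]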
     @staticmethod
     def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
-        """
+        r"""
         Converts a PIL image to RGB format.
+
+        Args:
+            image (`PIL.Image.Image`):
+                The PIL image to convert to RGB.
+
+        Returns:
+            `PIL.Image.Image`:
+                The RGB-converted PIL image.
         """
         image = image.convert("RGB")
@@ -173,8 +257,16 @@ class VaeImageProcessor(ConfigMixin):
     @staticmethod
     def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
-        """
-        Converts a PIL image to grayscale format.
+        r"""
+        Converts a given PIL image to grayscale.
+
+        Args:
+            image (`PIL.Image.Image`):
+                The input image to convert.
+
+        Returns:
+            `PIL.Image.Image`:
+                The image converted to grayscale.
         """
         image = image.convert("L")
@@ -182,8 +274,16 @@ class VaeImageProcessor(ConfigMixin):
     @staticmethod
     def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
-        """
+        r"""
         Applies Gaussian blur to an image.
+
+        Args:
+            image (`PIL.Image.Image`):
+                The PIL image to blur.
+            blur_factor (`int`, *optional*, defaults to 4):
+                The radius of the Gaussian blur.
+
+        Returns:
+            `PIL.Image.Image`:
+                The blurred PIL image.
         """
         image = image.filter(ImageFilter.GaussianBlur(blur_factor))
@@ -191,7 +291,7 @@ class VaeImageProcessor(ConfigMixin):
     @staticmethod
     def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
-        """
+        r"""
         Finds a rectangular region that contains all masked areas in an image, and expands the region to match the
         aspect ratio of the original image; for example, if the user drew a mask in a 128x32 region, and the
         dimensions for processing are 512x512, the region will be expanded to 128x128.
@@ -285,14 +385,21 @@ class VaeImageProcessor(ConfigMixin):
         width: int,
         height: int,
     ) -> PIL.Image.Image:
-        """
+        r"""
         Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
         the image within the dimensions, filling the empty space with data from the image.
 
         Args:
-            image: The image to resize.
-            width: The width to resize the image to.
-            height: The height to resize the image to.
+            image (`PIL.Image.Image`):
+                The image to resize and fill.
+            width (`int`):
+                The width to resize the image to.
+            height (`int`):
+                The height to resize the image to.
+
+        Returns:
+            `PIL.Image.Image`:
+                The resized and filled image.
         """
         ratio = width / height
@@ -330,14 +437,21 @@ class VaeImageProcessor(ConfigMixin):
         width: int,
         height: int,
     ) -> PIL.Image.Image:
-        """
+        r"""
         Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
         the image within the dimensions, cropping the excess.
 
         Args:
-            image: The image to resize.
-            width: The width to resize the image to.
-            height: The height to resize the image to.
+            image (`PIL.Image.Image`):
+                The image to resize and crop.
+            width (`int`):
+                The width to resize the image to.
+            height (`int`):
+                The height to resize the image to.
+
+        Returns:
+            `PIL.Image.Image`:
+                The resized and cropped image.
         """
         ratio = width / height
         src_ratio = image.width / image.height
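
The two centering strategies differ only in how the aspect-ratio mismatch is resolved; a hedged sketch (the private helper names `_resize_and_fill`/`_resize_and_crop` are assumed, since this hunk does not show them):

import PIL.Image

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor()
wide = PIL.Image.new("RGB", (1024, 256))

# "fill": fit the whole image inside 512x512, then pad the empty bands with image data
filled = processor._resize_and_fill(wide, 512, 512)   # assumed helper name
# "crop": scale the image to cover 512x512, then trim the horizontal overflow
cropped = processor._resize_and_crop(wide, 512, 512)  # assumed helper name
print(filled.size, cropped.size)  # (512, 512) (512, 512)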
@@ -429,19 +543,23 @@ class VaeImageProcessor(ConfigMixin):
         height: Optional[int] = None,
         width: Optional[int] = None,
     ) -> Tuple[int, int]:
-        """
-        This function return the height and width that are downscaled to the next integer multiple of
-        `vae_scale_factor`.
+        r"""
+        Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
 
         Args:
-            image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
-                The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
-                shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should
-                have shape `[batch, channel, height, width]`.
-            height (`int`, *optional*, defaults to `None`):
-                The height in preprocessed image. If `None`, will use the height of `image` input.
-            width (`int`, *optional*`, defaults to `None`):
-                The width in preprocessed. If `None`, will use the width of the `image` input.
+            image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
+                The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
+                should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
+                tensor, it should have shape `[batch, channels, height, width]`.
+            height (`Optional[int]`, *optional*, defaults to `None`):
+                The height of the preprocessed image. If `None`, the height of the `image` input will be used.
+            width (`Optional[int]`, *optional*, defaults to `None`):
+                The width of the preprocessed image. If `None`, the width of the `image` input will be used.
+
+        Returns:
+            `Tuple[int, int]`:
+                A tuple containing the height and width, both resized to the nearest integer multiple of
+                `vae_scale_factor`.
         """
         if height is None:
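
A worked example of the rounding, assuming the default `vae_scale_factor=8` (each side is floored to the nearest multiple of 8):

import PIL.Image

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)
image = PIL.Image.new("RGB", (769, 513))  # width=769, height=513
height, width = processor.get_default_height_width(image)
print(height, width)  # 512 768, i.e. 513 - 513 % 8 and 769 - 769 % 8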
@@ -478,13 +596,13 @@ class VaeImageProcessor(ConfigMixin):
         Preprocess the image input.
 
         Args:
-            image (`pipeline_image_input`):
+            image (`PipelineImageInput`):
                 The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; also accepts lists of
                 supported formats.
-            height (`int`, *optional*, defaults to `None`):
+            height (`int`, *optional*):
                 The height of the preprocessed image. If `None`, will use `get_default_height_width()` to get the
                 default height.
-            width (`int`, *optional*`, defaults to `None`):
+            width (`int`, *optional*):
                 The width of the preprocessed image. If `None`, will use `get_default_height_width()` to get the
                 default width.
             resize_mode (`str`, *optional*, defaults to `default`):
                 The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
@@ -496,6 +614,10 @@ class VaeImageProcessor(ConfigMixin):
                 supported for PIL image input.
             crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
                 The crop coordinates for each image in the batch. If `None`, will not crop the image.
+
+        Returns:
+            `torch.Tensor`:
+                The preprocessed image.
         """
         supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
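
`preprocess` is the entry point most pipelines call; a usage sketch (the input file name is hypothetical):

import PIL.Image

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)
image = PIL.Image.open("input.png").convert("RGB")  # hypothetical file
batch = processor.preprocess(image, height=512, width=512)
print(batch.shape)  # torch.Size([1, 3, 512, 512]), in [-1, 1] with the default do_normalize=True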
@@ -655,8 +777,22 @@ class VaeImageProcessor(ConfigMixin):
         image: PIL.Image.Image,
         crop_coords: Optional[Tuple[int, int, int, int]] = None,
     ) -> PIL.Image.Image:
-        """
-        overlay the inpaint output to the original image
+        r"""
+        Applies an overlay of the mask and the inpainted image on the original image.
+
+        Args:
+            mask (`PIL.Image.Image`):
+                The mask image that highlights regions to overlay.
+            init_image (`PIL.Image.Image`):
+                The original image to which the overlay is applied.
+            image (`PIL.Image.Image`):
+                The image to overlay onto the original.
+            crop_coords (`Tuple[int, int, int, int]`, *optional*):
+                Coordinates to crop the image. If provided, the image will be cropped accordingly.
+
+        Returns:
+            `PIL.Image.Image`:
+                The final image with the overlay applied.
         """
         width, height = image.width, image.height
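
A hedged sketch of `apply_overlay` in an inpainting postprocessing step (synthetic images, illustrative only):

import PIL.Image

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor()
init_image = PIL.Image.new("RGB", (512, 512), "blue")  # original input
inpainted = PIL.Image.new("RGB", (512, 512), "red")    # stand-in for pipeline output
mask = PIL.Image.new("L", (512, 512), 0)
mask.paste(255, (128, 128, 384, 384))  # white marks the repainted region

# Roughly: keep init_image where the mask is black, show `inpainted` where it is white
result = processor.apply_overlay(mask, init_image, inpainted)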
@@ -713,8 +849,16 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
     @staticmethod
     def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
-        """
-        Convert a NumPy image or a batch of images to a PIL image.
+        r"""
+        Convert a NumPy image or a batch of images to a list of PIL images.
+
+        Args:
+            images (`np.ndarray`):
+                The input NumPy array of images, which can be a single image or a batch.
+
+        Returns:
+            `List[PIL.Image.Image]`:
+                A list of PIL images converted from the input NumPy array.
         """
         if images.ndim == 3:
             images = images[None, ...]
@@ -729,8 +873,16 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
     @staticmethod
     def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
-        """
+        r"""
         Convert a PIL image or a list of PIL images to NumPy arrays.
+
+        Args:
+            images (`Union[List[PIL.Image.Image], PIL.Image.Image]`):
+                The input image or list of images to be converted.
+
+        Returns:
+            `np.ndarray`:
+                A NumPy array of the converted images.
         """
         if not isinstance(images, list):
             images = [images]
@@ -741,18 +893,30 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
     @staticmethod
     def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
-        """
-        Args:
-            image: RGB-like depth image
-
-        Returns: depth map
+        r"""
+        Convert an RGB-like depth image to a depth map.
+
+        Args:
+            image (`Union[np.ndarray, torch.Tensor]`):
+                The RGB-like depth image to convert.
+
+        Returns:
+            `Union[np.ndarray, torch.Tensor]`:
+                The corresponding depth map.
         """
         return image[:, :, 1] * 2**8 + image[:, :, 2]
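
The formula unpacks a 16-bit depth value from the G (high byte) and B (low byte) channels; a worked example:

import numpy as np

from diffusers.image_processor import VaeImageProcessorLDM3D

rgb_like = np.zeros((2, 2, 3), dtype=np.int64)
rgb_like[..., 1] = 1   # G channel: high byte
rgb_like[..., 2] = 44  # B channel: low byte
depth = VaeImageProcessorLDM3D.rgblike_to_depthmap(rgb_like)
print(depth)  # 1 * 2**8 + 44 = 300 everywhere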
     def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
-        """
-        Convert a NumPy depth image or a batch of images to a PIL image.
+        r"""
+        Convert a NumPy depth image or a batch of images to a list of PIL images.
+
+        Args:
+            images (`np.ndarray`):
+                The input NumPy array of depth images, which can be a single image or a batch.
+
+        Returns:
+            `List[PIL.Image.Image]`:
+                A list of PIL images converted from the input NumPy depth images.
         """
         if images.ndim == 3:
             images = images[None, ...]
@@ -833,8 +997,24 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
         width: Optional[int] = None,
         target_res: Optional[int] = None,
     ) -> torch.Tensor:
-        """
-        Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
+        r"""
+        Preprocess the image input. Accepted formats are PIL images, NumPy arrays, or PyTorch tensors.
+
+        Args:
+            rgb (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`):
+                The RGB input image, which can be a single image or a batch.
+            depth (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`):
+                The depth input image, which can be a single image or a batch.
+            height (`Optional[int]`, *optional*, defaults to `None`):
+                The desired height of the processed image. If `None`, defaults to the height of the input image.
+            width (`Optional[int]`, *optional*, defaults to `None`):
+                The desired width of the processed image. If `None`, defaults to the width of the input image.
+            target_res (`Optional[int]`, *optional*, defaults to `None`):
+                Target resolution for resizing the images. If specified, overrides height and width.
+
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`:
+                A tuple containing the processed RGB and depth images as PyTorch tensors.
         """
         supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
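
A hedged usage sketch of the LDM3D preprocessing path (synthetic placeholder inputs; exact output shapes depend on how the depth is encoded):

import PIL.Image

from diffusers.image_processor import VaeImageProcessorLDM3D

processor = VaeImageProcessorLDM3D(vae_scale_factor=8)
rgb_in = PIL.Image.new("RGB", (512, 512))
depth_in = PIL.Image.new("RGB", (512, 512))  # RGB-like encoded depth
rgb, depth = processor.preprocess(rgb_in, depth_in)  # both returned as 4D torch tensors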
@@ -1072,7 +1252,17 @@ class PixArtImageProcessor(VaeImageProcessor):
     @staticmethod
     def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
-        """Returns binned height and width."""
+        r"""
+        Returns the binned height and width based on the aspect ratio.
+
+        Args:
+            height (`int`): The height of the image.
+            width (`int`): The width of the image.
+            ratios (`dict`): A dictionary where keys are aspect ratios and values are tuples of (height, width).
+
+        Returns:
+            `Tuple[int, int]`: The closest binned height and width.
+        """
         ar = float(height / width)
         closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
         default_hw = ratios[closest_ratio]
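
A worked example; the `ratios` dict here is a toy stand-in for PixArt's aspect-ratio bins such as `ASPECT_RATIO_1024_BIN`:

from diffusers.image_processor import PixArtImageProcessor

ratios = {"0.5": (704, 1408), "1.0": (1024, 1024), "2.0": (1408, 704)}
print(PixArtImageProcessor.classify_height_width_bin(1000, 1100, ratios))
# ar = 1000 / 1100 ≈ 0.91, closest key is "1.0" -> (1024, 1024)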
@@ -1080,6 +1270,19 @@ class PixArtImageProcessor(VaeImageProcessor):
     @staticmethod
     def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
+        r"""
+        Resizes and crops a tensor of images to the specified dimensions.
+
+        Args:
+            samples (`torch.Tensor`):
+                A tensor of shape (N, C, H, W) where N is the batch size, C is the number of channels, H is the
+                height, and W is the width.
+            new_width (`int`): The desired width of the output images.
+            new_height (`int`): The desired height of the output images.
+
+        Returns:
+            `torch.Tensor`: A tensor containing the resized and cropped images.
+        """
         orig_height, orig_width = samples.shape[2], samples.shape[3]
 
         # Check if resizing is needed
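
A sketch of the expected behavior (scale to cover the target, then center-crop):

import torch

from diffusers.image_processor import PixArtImageProcessor

samples = torch.randn(2, 3, 1024, 1024)
out = PixArtImageProcessor.resize_and_crop_tensor(samples, new_width=768, new_height=512)
print(out.shape)  # torch.Size([2, 3, 512, 768])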