Unverified Commit 8ccb6194 authored by Nikhil Gajendrakumar, committed by GitHub

VaeImageProcessor: Allow image resizing also for torch and numpy inputs (#4832)


Co-authored-by: Nikhil Gajendrakumar <nikhilkatte@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 0699ac62
@@ -202,14 +202,27 @@ class VaeImageProcessor(ConfigMixin):
     def resize(
         self,
-        image: PIL.Image.Image,
+        image: [PIL.Image.Image, np.ndarray, torch.Tensor],
         height: Optional[int] = None,
         width: Optional[int] = None,
-    ) -> PIL.Image.Image:
+    ) -> [PIL.Image.Image, np.ndarray, torch.Tensor]:
         """
-        Resize a PIL image.
+        Resize image.
         """
-        image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
+        if isinstance(image, PIL.Image.Image):
+            image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
+        elif isinstance(image, torch.Tensor):
+            image = torch.nn.functional.interpolate(
+                image,
+                size=(height, width),
+            )
+        elif isinstance(image, np.ndarray):
+            image = self.numpy_to_pt(image)
+            image = torch.nn.functional.interpolate(
+                image,
+                size=(height, width),
+            )
+            image = self.pt_to_numpy(image)
         return image

     def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
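
For context, a minimal usage sketch of the extended `resize` (not part of the diff). It assumes `VaeImageProcessor` is importable from `diffusers.image_processor`, that torch inputs are batched NCHW with values in [0, 1], and that numpy inputs are batched NHWC to match `numpy_to_pt`/`pt_to_numpy`; the sizes and `vae_scale_factor` below are illustrative.

```python
# Illustrative only: exercising the PIL, torch, and numpy branches of resize().
import numpy as np
import PIL.Image
import torch

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(do_resize=True, vae_scale_factor=8)

# torch.Tensor input, NCHW, values in [0, 1]
pt_image = torch.rand(1, 3, 512, 512)
print(processor.resize(image=pt_image, height=256, width=256).shape)   # torch.Size([1, 3, 256, 256])

# np.ndarray input, NHWC; internally round-trips through numpy_to_pt/pt_to_numpy
np_image = np.random.rand(1, 512, 512, 3).astype(np.float32)
print(processor.resize(image=np_image, height=256, width=256).shape)   # (1, 256, 256, 3)

# PIL input keeps the original behaviour
pil_image = PIL.Image.new("RGB", (512, 512))
print(processor.resize(image=pil_image, height=256, width=256).size)   # (256, 256)
```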
@@ -273,11 +286,8 @@ class VaeImageProcessor(ConfigMixin):
             image = self.numpy_to_pt(image)

             height, width = self.get_default_height_width(image, height, width)
-            if self.config.do_resize and (image.shape[2] != height or image.shape[3] != width):
-                raise ValueError(
-                    f"Currently we only support resizing for PIL image - please resize your numpy array to be {height} and {width}"
-                    f"currently the sizes are {image.shape[2]} and {image.shape[3]}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
-                )
+            if self.config.do_resize:
+                image = self.resize(image, height, width)

         elif isinstance(image[0], torch.Tensor):
             image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
@@ -291,11 +301,8 @@ class VaeImageProcessor(ConfigMixin):
                 return image

             height, width = self.get_default_height_width(image, height, width)
-            if self.config.do_resize and (image.shape[2] != height or image.shape[3] != width):
-                raise ValueError(
-                    f"Currently we only support resizing for PIL image - please resize your torch tensor to be {height} and {width}"
-                    f"currently the sizes are {image.shape[2]} and {image.shape[3]}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
-                )
+            if self.config.do_resize:
+                image = self.resize(image, height, width)

         # expected range [0,1], normalize to [-1,1]
         do_normalize = self.config.do_normalize
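
A hedged sketch of what the two `preprocess` changes above enable: instead of raising a `ValueError` when the input does not match the requested `height`/`width`, torch and numpy inputs are now resized whenever `do_resize` is set. The concrete sizes and `vae_scale_factor` are illustrative; both calls return a normalized NCHW torch tensor, since `preprocess` converts numpy inputs to torch before resizing.

```python
# Illustrative only: preprocess() now resizes non-PIL inputs itself.
import numpy as np
import torch

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(do_resize=True, vae_scale_factor=8)

# Previously both calls raised a ValueError because 600 != 512.
pt_out = processor.preprocess(torch.rand(1, 3, 600, 600), height=512, width=512)
np_out = processor.preprocess(np.random.rand(1, 600, 600, 3).astype(np.float32), height=512, width=512)

print(pt_out.shape, np_out.shape)  # torch.Size([1, 3, 512, 512]) torch.Size([1, 3, 512, 512])
```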
@@ -285,3 +285,26 @@ class ImageProcessorTest(unittest.TestCase):
         )
         assert np.abs(out_np_3d - out_np_3d_list).max() < 1e-6
+
+    def test_vae_image_processor_resize_pt(self):
+        image_processor = VaeImageProcessor(do_resize=True, vae_scale_factor=1)
+        input_pt = self.dummy_sample
+        b, c, h, w = input_pt.shape
+        scale = 2
+        out_pt = image_processor.resize(image=input_pt, height=h // scale, width=w // scale)
+        exp_pt_shape = (b, c, h // scale, w // scale)
+        assert (
+            out_pt.shape == exp_pt_shape
+        ), f"resized image output shape '{out_pt.shape}' didn't match expected shape '{exp_pt_shape}'."
+
+    def test_vae_image_processor_resize_np(self):
+        image_processor = VaeImageProcessor(do_resize=True, vae_scale_factor=1)
+        input_pt = self.dummy_sample
+        b, c, h, w = input_pt.shape
+        scale = 2
+        input_np = self.to_np(input_pt)
+        out_np = image_processor.resize(image=input_np, height=h // scale, width=w // scale)
+        exp_np_shape = (b, h // scale, w // scale, c)
+        assert (
+            out_np.shape == exp_np_shape
+        ), f"resized image output shape '{out_np.shape}' didn't match expected shape '{exp_np_shape}'."
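
The two tests expect different output layouts because the numpy branch of `resize` round-trips through torch: `numpy_to_pt` moves NHWC to NCHW before `torch.nn.functional.interpolate`, and `pt_to_numpy` moves it back, so numpy outputs stay channel-last. A standalone sketch of that round trip, independent of the test fixtures above:

```python
# Illustrative round trip mirroring the numpy branch of resize().
import numpy as np
import torch

x = np.random.rand(2, 64, 64, 3).astype(np.float32)      # NHWC input
pt = torch.from_numpy(x.transpose(0, 3, 1, 2))            # -> NCHW, as numpy_to_pt does
pt = torch.nn.functional.interpolate(pt, size=(32, 32))   # resize on the torch side
y = pt.permute(0, 2, 3, 1).numpy()                        # -> NHWC, as pt_to_numpy does
print(y.shape)  # (2, 32, 32, 3)
```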