"git@developer.sourcefind.cn:change/sglang.git" did not exist on "0edda32001938b578976409216bc6f9f36f719df"
Unverified Commit 8ccb6194 authored by Nikhil Gajendrakumar's avatar Nikhil Gajendrakumar Committed by GitHub
Browse files

VaeImageProcessor: Allow image resizing also for torch and numpy inputs (#4832)


Co-authored-by: default avatarNikhil Gajendrakumar <nikhilkatte@gmail.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 0699ac62
...@@ -202,14 +202,27 @@ class VaeImageProcessor(ConfigMixin): ...@@ -202,14 +202,27 @@ class VaeImageProcessor(ConfigMixin):
def resize( def resize(
self, self,
image: PIL.Image.Image, image: [PIL.Image.Image, np.ndarray, torch.Tensor],
height: Optional[int] = None, height: Optional[int] = None,
width: Optional[int] = None, width: Optional[int] = None,
) -> PIL.Image.Image: ) -> [PIL.Image.Image, np.ndarray, torch.Tensor]:
""" """
Resize a PIL image. Resize image.
""" """
if isinstance(image, PIL.Image.Image):
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample]) image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
elif isinstance(image, torch.Tensor):
image = torch.nn.functional.interpolate(
image,
size=(height, width),
)
elif isinstance(image, np.ndarray):
image = self.numpy_to_pt(image)
image = torch.nn.functional.interpolate(
image,
size=(height, width),
)
image = self.pt_to_numpy(image)
return image return image
def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image: def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
...@@ -273,11 +286,8 @@ class VaeImageProcessor(ConfigMixin): ...@@ -273,11 +286,8 @@ class VaeImageProcessor(ConfigMixin):
image = self.numpy_to_pt(image) image = self.numpy_to_pt(image)
height, width = self.get_default_height_width(image, height, width) height, width = self.get_default_height_width(image, height, width)
if self.config.do_resize and (image.shape[2] != height or image.shape[3] != width): if self.config.do_resize:
raise ValueError( image = self.resize(image, height, width)
f"Currently we only support resizing for PIL image - please resize your numpy array to be {height} and {width}"
f"currently the sizes are {image.shape[2]} and {image.shape[3]}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
)
elif isinstance(image[0], torch.Tensor): elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
...@@ -291,11 +301,8 @@ class VaeImageProcessor(ConfigMixin): ...@@ -291,11 +301,8 @@ class VaeImageProcessor(ConfigMixin):
return image return image
height, width = self.get_default_height_width(image, height, width) height, width = self.get_default_height_width(image, height, width)
if self.config.do_resize and (image.shape[2] != height or image.shape[3] != width): if self.config.do_resize:
raise ValueError( image = self.resize(image, height, width)
f"Currently we only support resizing for PIL image - please resize your torch tensor to be {height} and {width}"
f"currently the sizes are {image.shape[2]} and {image.shape[3]}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
)
# expected range [0,1], normalize to [-1,1] # expected range [0,1], normalize to [-1,1]
do_normalize = self.config.do_normalize do_normalize = self.config.do_normalize
......
...@@ -285,3 +285,26 @@ class ImageProcessorTest(unittest.TestCase): ...@@ -285,3 +285,26 @@ class ImageProcessorTest(unittest.TestCase):
) )
assert np.abs(out_np_3d - out_np_3d_list).max() < 1e-6 assert np.abs(out_np_3d - out_np_3d_list).max() < 1e-6
def test_vae_image_processor_resize_pt(self):
image_processor = VaeImageProcessor(do_resize=True, vae_scale_factor=1)
input_pt = self.dummy_sample
b, c, h, w = input_pt.shape
scale = 2
out_pt = image_processor.resize(image=input_pt, height=h // scale, width=w // scale)
exp_pt_shape = (b, c, h // scale, w // scale)
assert (
out_pt.shape == exp_pt_shape
), f"resized image output shape '{out_pt.shape}' didn't match expected shape '{exp_pt_shape}'."
def test_vae_image_processor_resize_np(self):
image_processor = VaeImageProcessor(do_resize=True, vae_scale_factor=1)
input_pt = self.dummy_sample
b, c, h, w = input_pt.shape
scale = 2
input_np = self.to_np(input_pt)
out_np = image_processor.resize(image=input_np, height=h // scale, width=w // scale)
exp_np_shape = (b, h // scale, w // scale, c)
assert (
out_np.shape == exp_np_shape
), f"resized image output shape '{out_np.shape}' didn't match expected shape '{exp_np_shape}'."
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment