Unverified Commit 7de68b0d authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

unwrap features in dispatchers (#6831)

* unwrap features in dispatchers

* cleanup

* align erase / five_crop / ten_crop with other dispatchers
parent 0d7807d5
...@@ -34,10 +34,15 @@ def erase( ...@@ -34,10 +34,15 @@ def erase(
v: torch.Tensor, v: torch.Tensor,
inplace: bool = False, inplace: bool = False,
) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]: ) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]:
if isinstance(inpt, torch.Tensor): if isinstance(inpt, torch.Tensor) and (
output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)): ):
output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
return output elif isinstance(inpt, features.Image):
output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
return features.Image.wrap_like(inpt, output)
elif isinstance(inpt, features.Video):
output = erase_video(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
return features.Video.wrap_like(inpt, output)
else: # isinstance(inpt, PIL.Image.Image): else: # isinstance(inpt, PIL.Image.Image):
return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
...@@ -25,12 +25,13 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima ...@@ -25,12 +25,13 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
def rgb_to_grayscale( def rgb_to_grayscale(
inpt: Union[features.LegacyImageTypeJIT, features.LegacyVideoTypeJIT], num_output_channels: int = 1 inpt: Union[features.LegacyImageTypeJIT, features.LegacyVideoTypeJIT], num_output_channels: int = 1
) -> Union[features.LegacyImageTypeJIT, features.LegacyVideoTypeJIT]: ) -> Union[features.LegacyImageTypeJIT, features.LegacyVideoTypeJIT]:
old_color_space = ( if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
features._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type] inpt = inpt.as_subclass(torch.Tensor)
if isinstance(inpt, torch.Tensor) old_color_space = None
and (torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))) elif isinstance(inpt, torch.Tensor):
else None old_color_space = features._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type]
) else:
old_color_space = None
call = ", num_output_channels=3" if num_output_channels == 3 else "" call = ", num_output_channels=3" if num_output_channels == 3 else ""
replacement = ( replacement = (
......
def five_crop(
    inpt: ImageOrVideoTypeJIT, size: List[int]
) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
    """Dispatch a five-crop to the kernel matching the runtime type of ``inpt``.

    Plain tensors (and any tensor while TorchScript is scripting) are routed
    straight to the tensor kernel; ``features.Image`` / ``features.Video``
    inputs are unwrapped, cropped, and each of the five crops is re-wrapped
    as the same feature type; everything else is assumed to be a PIL image.
    """
    # TODO: consider breaking BC here to return List[features.ImageTypeJIT/VideoTypeJIT] to align this op with `ten_crop`
    treat_as_plain_tensor = isinstance(inpt, torch.Tensor) and (
        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
    )
    if treat_as_plain_tensor:
        return five_crop_image_tensor(inpt, size)
    if isinstance(inpt, features.Image):
        crops = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size)
        # Re-wrap each crop so callers keep getting features.Image back.
        return tuple(features.Image.wrap_like(inpt, crop) for crop in crops)  # type: ignore[return-value]
    if isinstance(inpt, features.Video):
        crops = five_crop_video(inpt.as_subclass(torch.Tensor), size)
        return tuple(features.Video.wrap_like(inpt, crop) for crop in crops)  # type: ignore[return-value]
    # Fallthrough: isinstance(inpt, PIL.Image.Image)
    return five_crop_image_pil(inpt, size)
def ten_crop(
    inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], size: List[int], vertical_flip: bool = False
) -> Union[List[features.ImageTypeJIT], List[features.VideoTypeJIT]]:
    """Dispatch a ten-crop to the kernel matching the runtime type of ``inpt``.

    Plain tensors (and any tensor while TorchScript is scripting) go directly
    to the tensor kernel; ``features.Image`` / ``features.Video`` inputs are
    unwrapped before the kernel call and every resulting crop is re-wrapped as
    the original feature type; all remaining inputs are treated as PIL images.
    """
    treat_as_plain_tensor = isinstance(inpt, torch.Tensor) and (
        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
    )
    if treat_as_plain_tensor:
        return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
    if isinstance(inpt, features.Image):
        crops = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
        # Re-wrap each crop so callers keep getting features.Image back.
        return [features.Image.wrap_like(inpt, crop) for crop in crops]
    if isinstance(inpt, features.Video):
        crops = ten_crop_video(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
        return [features.Video.wrap_like(inpt, crop) for crop in crops]
    # Fallthrough: isinstance(inpt, PIL.Image.Image)
    return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.