"docs/vscode:/vscode.git/clone" did not exist on "b1290d3fb8aadd0d4423c5a62711f0bf490ab577"
Unverified Commit 7de68b0d authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

unwrap features in dispatchers (#6831)

* unwrap features in dispatchers

* cleanup

* align erase / five_crop / ten_crop with other dispatchers
parent 0d7807d5
def erase(
    inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT],  # NOTE(review): params reconstructed from the hunk header — confirm against full file
    i: int,
    j: int,
    h: int,
    w: int,
    v: torch.Tensor,
    inplace: bool = False,
) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]:
    """Erase the rectangle at ``(i, j)`` with height ``h`` and width ``w``, filling it with ``v``.

    Dispatches on the input type: plain tensors go straight to the tensor kernel,
    ``features.Image`` / ``features.Video`` are unwrapped to plain tensors, run through
    the matching kernel, and re-wrapped, and anything else is assumed to be a PIL image.
    """
    # Under torch.jit scripting, feature subclasses are indistinguishable from plain
    # tensors, so the scripting check must short-circuit to the tensor kernel.
    if isinstance(inpt, torch.Tensor) and (
        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
    ):
        return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
    elif isinstance(inpt, features.Image):
        # Unwrap to a plain tensor for the kernel, then restore the feature metadata.
        output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
        return features.Image.wrap_like(inpt, output)
    elif isinstance(inpt, features.Video):
        output = erase_video(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
        return features.Video.wrap_like(inpt, output)
    else:  # isinstance(inpt, PIL.Image.Image):
        return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
......@@ -25,12 +25,13 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
def rgb_to_grayscale(
inpt: Union[features.LegacyImageTypeJIT, features.LegacyVideoTypeJIT], num_output_channels: int = 1
) -> Union[features.LegacyImageTypeJIT, features.LegacyVideoTypeJIT]:
old_color_space = (
features._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type]
if isinstance(inpt, torch.Tensor)
and (torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video)))
else None
)
if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
inpt = inpt.as_subclass(torch.Tensor)
old_color_space = None
elif isinstance(inpt, torch.Tensor):
old_color_space = features._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type]
else:
old_color_space = None
call = ", num_output_channels=3" if num_output_channels == 3 else ""
replacement = (
......
def five_crop(
    inpt: ImageOrVideoTypeJIT, size: List[int]
) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
    """Crop the input into four corners and the central crop, returning the five crops as a tuple.

    Dispatches on the input type, unwrapping ``features.Image`` / ``features.Video`` to
    plain tensors for the kernel and re-wrapping each crop in the result.
    """
    # TODO: consider breaking BC here to return List[features.ImageTypeJIT/VideoTypeJIT] to align this op with `ten_crop`
    # Under torch.jit scripting, feature subclasses are indistinguishable from plain
    # tensors, so the scripting check must short-circuit to the tensor kernel.
    if isinstance(inpt, torch.Tensor) and (
        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
    ):
        return five_crop_image_tensor(inpt, size)
    elif isinstance(inpt, features.Image):
        # Unwrap for the kernel, then restore feature metadata on every crop.
        output = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size)
        return tuple(features.Image.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
    elif isinstance(inpt, features.Video):
        output = five_crop_video(inpt.as_subclass(torch.Tensor), size)
        return tuple(features.Video.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
    else:  # isinstance(inpt, PIL.Image.Image):
        return five_crop_image_pil(inpt, size)
......@@ -1444,10 +1448,15 @@ def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = F
def ten_crop(
    inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], size: List[int], vertical_flip: bool = False
) -> Union[List[features.ImageTypeJIT], List[features.VideoTypeJIT]]:
    """Crop the input into four corners and the central crop, plus the flipped versions of each.

    Returns a list of ten crops. Dispatches on the input type, unwrapping
    ``features.Image`` / ``features.Video`` to plain tensors for the kernel and
    re-wrapping each crop in the result.
    """
    # Under torch.jit scripting, feature subclasses are indistinguishable from plain
    # tensors, so the scripting check must short-circuit to the tensor kernel.
    if isinstance(inpt, torch.Tensor) and (
        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
    ):
        return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
    elif isinstance(inpt, features.Image):
        # Unwrap for the kernel, then restore feature metadata on every crop.
        output = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
        return [features.Image.wrap_like(inpt, item) for item in output]
    elif isinstance(inpt, features.Video):
        output = ten_crop_video(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
        return [features.Video.wrap_like(inpt, item) for item in output]
    else:  # isinstance(inpt, PIL.Image.Image):
        return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment