Unverified Commit 37618552 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Unwrap features before passing them into a kernel (#6807)

* unwrap features before calling the kernels

* revert double unwrapping

* cleanup

* remove debug raise

* more cleanup
parent d0de55db
......@@ -66,15 +66,23 @@ class BoundingBox(_Feature):
format = BoundingBoxFormat.from_str(format.upper())
return BoundingBox.wrap_like(
self, self._F.convert_format_bounding_box(self, old_format=self.format, new_format=format), format=format
self,
self._F.convert_format_bounding_box(
self.as_subclass(torch.Tensor), old_format=self.format, new_format=format
),
format=format,
)
def horizontal_flip(self) -> BoundingBox:
output = self._F.horizontal_flip_bounding_box(self, format=self.format, spatial_size=self.spatial_size)
output = self._F.horizontal_flip_bounding_box(
self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size
)
return BoundingBox.wrap_like(self, output)
def vertical_flip(self) -> BoundingBox:
output = self._F.vertical_flip_bounding_box(self, format=self.format, spatial_size=self.spatial_size)
output = self._F.vertical_flip_bounding_box(
self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size
)
return BoundingBox.wrap_like(self, output)
def resize( # type: ignore[override]
......@@ -85,19 +93,19 @@ class BoundingBox(_Feature):
antialias: bool = False,
) -> BoundingBox:
output, spatial_size = self._F.resize_bounding_box(
self, spatial_size=self.spatial_size, size=size, max_size=max_size
self.as_subclass(torch.Tensor), spatial_size=self.spatial_size, size=size, max_size=max_size
)
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox:
output, spatial_size = self._F.crop_bounding_box(
self, self.format, top=top, left=left, height=height, width=width
self.as_subclass(torch.Tensor), self.format, top=top, left=left, height=height, width=width
)
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
def center_crop(self, output_size: List[int]) -> BoundingBox:
output, spatial_size = self._F.center_crop_bounding_box(
self, format=self.format, spatial_size=self.spatial_size, output_size=output_size
self.as_subclass(torch.Tensor), format=self.format, spatial_size=self.spatial_size, output_size=output_size
)
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
......@@ -111,7 +119,9 @@ class BoundingBox(_Feature):
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: bool = False,
) -> BoundingBox:
output, spatial_size = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size)
output, spatial_size = self._F.resized_crop_bounding_box(
self.as_subclass(torch.Tensor), self.format, top, left, height, width, size=size
)
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
def pad(
......@@ -121,7 +131,11 @@ class BoundingBox(_Feature):
padding_mode: str = "constant",
) -> BoundingBox:
output, spatial_size = self._F.pad_bounding_box(
self, format=self.format, spatial_size=self.spatial_size, padding=padding, padding_mode=padding_mode
self.as_subclass(torch.Tensor),
format=self.format,
spatial_size=self.spatial_size,
padding=padding,
padding_mode=padding_mode,
)
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
......@@ -134,7 +148,12 @@ class BoundingBox(_Feature):
center: Optional[List[float]] = None,
) -> BoundingBox:
output, spatial_size = self._F.rotate_bounding_box(
self, format=self.format, spatial_size=self.spatial_size, angle=angle, expand=expand, center=center
self.as_subclass(torch.Tensor),
format=self.format,
spatial_size=self.spatial_size,
angle=angle,
expand=expand,
center=center,
)
return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
......@@ -149,7 +168,7 @@ class BoundingBox(_Feature):
center: Optional[List[float]] = None,
) -> BoundingBox:
output = self._F.affine_bounding_box(
self,
self.as_subclass(torch.Tensor),
self.format,
self.spatial_size,
angle,
......@@ -166,7 +185,7 @@ class BoundingBox(_Feature):
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> BoundingBox:
output = self._F.perspective_bounding_box(self, self.format, perspective_coeffs)
output = self._F.perspective_bounding_box(self.as_subclass(torch.Tensor), self.format, perspective_coeffs)
return BoundingBox.wrap_like(self, output)
def elastic(
......@@ -175,5 +194,5 @@ class BoundingBox(_Feature):
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> BoundingBox:
output = self._F.elastic_bounding_box(self, self.format, displacement)
output = self._F.elastic_bounding_box(self.as_subclass(torch.Tensor), self.format, displacement)
return BoundingBox.wrap_like(self, output)
......@@ -117,17 +117,17 @@ class Image(_Feature):
return Image.wrap_like(
self,
self._F.convert_color_space_image_tensor(
self, old_color_space=self.color_space, new_color_space=color_space, copy=copy
self.as_subclass(torch.Tensor), old_color_space=self.color_space, new_color_space=color_space, copy=copy
),
color_space=color_space,
)
def horizontal_flip(self) -> Image:
output = self._F.horizontal_flip_image_tensor(self)
output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor))
return Image.wrap_like(self, output)
def vertical_flip(self) -> Image:
output = self._F.vertical_flip_image_tensor(self)
output = self._F.vertical_flip_image_tensor(self.as_subclass(torch.Tensor))
return Image.wrap_like(self, output)
def resize( # type: ignore[override]
......@@ -138,16 +138,16 @@ class Image(_Feature):
antialias: bool = False,
) -> Image:
output = self._F.resize_image_tensor(
self, size, interpolation=interpolation, max_size=max_size, antialias=antialias
self.as_subclass(torch.Tensor), size, interpolation=interpolation, max_size=max_size, antialias=antialias
)
return Image.wrap_like(self, output)
def crop(self, top: int, left: int, height: int, width: int) -> Image:
output = self._F.crop_image_tensor(self, top, left, height, width)
output = self._F.crop_image_tensor(self.as_subclass(torch.Tensor), top, left, height, width)
return Image.wrap_like(self, output)
def center_crop(self, output_size: List[int]) -> Image:
output = self._F.center_crop_image_tensor(self, output_size=output_size)
output = self._F.center_crop_image_tensor(self.as_subclass(torch.Tensor), output_size=output_size)
return Image.wrap_like(self, output)
def resized_crop(
......@@ -161,7 +161,14 @@ class Image(_Feature):
antialias: bool = False,
) -> Image:
output = self._F.resized_crop_image_tensor(
self, top, left, height, width, size=list(size), interpolation=interpolation, antialias=antialias
self.as_subclass(torch.Tensor),
top,
left,
height,
width,
size=list(size),
interpolation=interpolation,
antialias=antialias,
)
return Image.wrap_like(self, output)
......@@ -171,7 +178,7 @@ class Image(_Feature):
fill: FillTypeJIT = None,
padding_mode: str = "constant",
) -> Image:
output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
output = self._F.pad_image_tensor(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)
return Image.wrap_like(self, output)
def rotate(
......@@ -182,8 +189,8 @@ class Image(_Feature):
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Image:
output = self._F._geometry.rotate_image_tensor(
self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center
output = self._F.rotate_image_tensor(
self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center
)
return Image.wrap_like(self, output)
......@@ -197,8 +204,8 @@ class Image(_Feature):
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Image:
output = self._F._geometry.affine_image_tensor(
self,
output = self._F.affine_image_tensor(
self.as_subclass(torch.Tensor),
angle,
translate=translate,
scale=scale,
......@@ -215,8 +222,8 @@ class Image(_Feature):
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> Image:
output = self._F._geometry.perspective_image_tensor(
self, perspective_coeffs, interpolation=interpolation, fill=fill
output = self._F.perspective_image_tensor(
self.as_subclass(torch.Tensor), perspective_coeffs, interpolation=interpolation, fill=fill
)
return Image.wrap_like(self, output)
......@@ -226,55 +233,65 @@ class Image(_Feature):
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> Image:
output = self._F._geometry.elastic_image_tensor(self, displacement, interpolation=interpolation, fill=fill)
output = self._F.elastic_image_tensor(
self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill
)
return Image.wrap_like(self, output)
def adjust_brightness(self, brightness_factor: float) -> Image:
output = self._F.adjust_brightness_image_tensor(self, brightness_factor=brightness_factor)
output = self._F.adjust_brightness_image_tensor(
self.as_subclass(torch.Tensor), brightness_factor=brightness_factor
)
return Image.wrap_like(self, output)
def adjust_saturation(self, saturation_factor: float) -> Image:
output = self._F.adjust_saturation_image_tensor(self, saturation_factor=saturation_factor)
output = self._F.adjust_saturation_image_tensor(
self.as_subclass(torch.Tensor), saturation_factor=saturation_factor
)
return Image.wrap_like(self, output)
def adjust_contrast(self, contrast_factor: float) -> Image:
output = self._F.adjust_contrast_image_tensor(self, contrast_factor=contrast_factor)
output = self._F.adjust_contrast_image_tensor(self.as_subclass(torch.Tensor), contrast_factor=contrast_factor)
return Image.wrap_like(self, output)
def adjust_sharpness(self, sharpness_factor: float) -> Image:
output = self._F.adjust_sharpness_image_tensor(self, sharpness_factor=sharpness_factor)
output = self._F.adjust_sharpness_image_tensor(
self.as_subclass(torch.Tensor), sharpness_factor=sharpness_factor
)
return Image.wrap_like(self, output)
def adjust_hue(self, hue_factor: float) -> Image:
output = self._F.adjust_hue_image_tensor(self, hue_factor=hue_factor)
output = self._F.adjust_hue_image_tensor(self.as_subclass(torch.Tensor), hue_factor=hue_factor)
return Image.wrap_like(self, output)
def adjust_gamma(self, gamma: float, gain: float = 1) -> Image:
output = self._F.adjust_gamma_image_tensor(self, gamma=gamma, gain=gain)
output = self._F.adjust_gamma_image_tensor(self.as_subclass(torch.Tensor), gamma=gamma, gain=gain)
return Image.wrap_like(self, output)
def posterize(self, bits: int) -> Image:
output = self._F.posterize_image_tensor(self, bits=bits)
output = self._F.posterize_image_tensor(self.as_subclass(torch.Tensor), bits=bits)
return Image.wrap_like(self, output)
def solarize(self, threshold: float) -> Image:
output = self._F.solarize_image_tensor(self, threshold=threshold)
output = self._F.solarize_image_tensor(self.as_subclass(torch.Tensor), threshold=threshold)
return Image.wrap_like(self, output)
def autocontrast(self) -> Image:
output = self._F.autocontrast_image_tensor(self)
output = self._F.autocontrast_image_tensor(self.as_subclass(torch.Tensor))
return Image.wrap_like(self, output)
def equalize(self) -> Image:
output = self._F.equalize_image_tensor(self)
output = self._F.equalize_image_tensor(self.as_subclass(torch.Tensor))
return Image.wrap_like(self, output)
def invert(self) -> Image:
output = self._F.invert_image_tensor(self)
output = self._F.invert_image_tensor(self.as_subclass(torch.Tensor))
return Image.wrap_like(self, output)
def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image:
output = self._F.gaussian_blur_image_tensor(self, kernel_size=kernel_size, sigma=sigma)
output = self._F.gaussian_blur_image_tensor(
self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma
)
return Image.wrap_like(self, output)
......
......@@ -37,11 +37,11 @@ class Mask(_Feature):
return cast(Tuple[int, int], tuple(self.shape[-2:]))
def horizontal_flip(self) -> Mask:
output = self._F.horizontal_flip_mask(self)
output = self._F.horizontal_flip_mask(self.as_subclass(torch.Tensor))
return Mask.wrap_like(self, output)
def vertical_flip(self) -> Mask:
output = self._F.vertical_flip_mask(self)
output = self._F.vertical_flip_mask(self.as_subclass(torch.Tensor))
return Mask.wrap_like(self, output)
def resize( # type: ignore[override]
......@@ -51,15 +51,15 @@ class Mask(_Feature):
max_size: Optional[int] = None,
antialias: bool = False,
) -> Mask:
output = self._F.resize_mask(self, size, max_size=max_size)
output = self._F.resize_mask(self.as_subclass(torch.Tensor), size, max_size=max_size)
return Mask.wrap_like(self, output)
def crop(self, top: int, left: int, height: int, width: int) -> Mask:
output = self._F.crop_mask(self, top, left, height, width)
output = self._F.crop_mask(self.as_subclass(torch.Tensor), top, left, height, width)
return Mask.wrap_like(self, output)
def center_crop(self, output_size: List[int]) -> Mask:
output = self._F.center_crop_mask(self, output_size=output_size)
output = self._F.center_crop_mask(self.as_subclass(torch.Tensor), output_size=output_size)
return Mask.wrap_like(self, output)
def resized_crop(
......@@ -72,7 +72,7 @@ class Mask(_Feature):
interpolation: InterpolationMode = InterpolationMode.NEAREST,
antialias: bool = False,
) -> Mask:
output = self._F.resized_crop_mask(self, top, left, height, width, size=size)
output = self._F.resized_crop_mask(self.as_subclass(torch.Tensor), top, left, height, width, size=size)
return Mask.wrap_like(self, output)
def pad(
......@@ -81,7 +81,7 @@ class Mask(_Feature):
fill: FillTypeJIT = None,
padding_mode: str = "constant",
) -> Mask:
output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill)
output = self._F.pad_mask(self.as_subclass(torch.Tensor), padding, padding_mode=padding_mode, fill=fill)
return Mask.wrap_like(self, output)
def rotate(
......@@ -92,7 +92,7 @@ class Mask(_Feature):
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Mask:
output = self._F.rotate_mask(self, angle, expand=expand, center=center, fill=fill)
output = self._F.rotate_mask(self.as_subclass(torch.Tensor), angle, expand=expand, center=center, fill=fill)
return Mask.wrap_like(self, output)
def affine(
......@@ -106,7 +106,7 @@ class Mask(_Feature):
center: Optional[List[float]] = None,
) -> Mask:
output = self._F.affine_mask(
self,
self.as_subclass(torch.Tensor),
angle,
translate=translate,
scale=scale,
......@@ -122,7 +122,7 @@ class Mask(_Feature):
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
) -> Mask:
output = self._F.perspective_mask(self, perspective_coeffs, fill=fill)
output = self._F.perspective_mask(self.as_subclass(torch.Tensor), perspective_coeffs, fill=fill)
return Mask.wrap_like(self, output)
def elastic(
......@@ -131,5 +131,5 @@ class Mask(_Feature):
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
) -> Mask:
output = self._F.elastic_mask(self, displacement, fill=fill)
output = self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill)
return Mask.wrap_like(self, output)
......@@ -73,17 +73,17 @@ class Video(_Feature):
return Video.wrap_like(
self,
self._F.convert_color_space_video(
self, old_color_space=self.color_space, new_color_space=color_space, copy=copy
self.as_subclass(torch.Tensor), old_color_space=self.color_space, new_color_space=color_space, copy=copy
),
color_space=color_space,
)
def horizontal_flip(self) -> Video:
output = self._F.horizontal_flip_video(self)
output = self._F.horizontal_flip_video(self.as_subclass(torch.Tensor))
return Video.wrap_like(self, output)
def vertical_flip(self) -> Video:
output = self._F.vertical_flip_video(self)
output = self._F.vertical_flip_video(self.as_subclass(torch.Tensor))
return Video.wrap_like(self, output)
def resize( # type: ignore[override]
......@@ -93,15 +93,21 @@ class Video(_Feature):
max_size: Optional[int] = None,
antialias: bool = False,
) -> Video:
output = self._F.resize_video(self, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
output = self._F.resize_video(
self.as_subclass(torch.Tensor),
size,
interpolation=interpolation,
max_size=max_size,
antialias=antialias,
)
return Video.wrap_like(self, output)
def crop(self, top: int, left: int, height: int, width: int) -> Video:
output = self._F.crop_video(self, top, left, height, width)
output = self._F.crop_video(self.as_subclass(torch.Tensor), top, left, height, width)
return Video.wrap_like(self, output)
def center_crop(self, output_size: List[int]) -> Video:
output = self._F.center_crop_video(self, output_size=output_size)
output = self._F.center_crop_video(self.as_subclass(torch.Tensor), output_size=output_size)
return Video.wrap_like(self, output)
def resized_crop(
......@@ -115,7 +121,14 @@ class Video(_Feature):
antialias: bool = False,
) -> Video:
output = self._F.resized_crop_video(
self, top, left, height, width, size=list(size), interpolation=interpolation, antialias=antialias
self.as_subclass(torch.Tensor),
top,
left,
height,
width,
size=list(size),
interpolation=interpolation,
antialias=antialias,
)
return Video.wrap_like(self, output)
......@@ -125,7 +138,7 @@ class Video(_Feature):
fill: FillTypeJIT = None,
padding_mode: str = "constant",
) -> Video:
output = self._F.pad_video(self, padding, fill=fill, padding_mode=padding_mode)
output = self._F.pad_video(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)
return Video.wrap_like(self, output)
def rotate(
......@@ -136,8 +149,8 @@ class Video(_Feature):
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Video:
output = self._F._geometry.rotate_video(
self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center
output = self._F.rotate_video(
self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center
)
return Video.wrap_like(self, output)
......@@ -151,8 +164,8 @@ class Video(_Feature):
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Video:
output = self._F._geometry.affine_video(
self,
output = self._F.affine_video(
self.as_subclass(torch.Tensor),
angle,
translate=translate,
scale=scale,
......@@ -169,7 +182,9 @@ class Video(_Feature):
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> Video:
output = self._F._geometry.perspective_video(self, perspective_coeffs, interpolation=interpolation, fill=fill)
output = self._F.perspective_video(
self.as_subclass(torch.Tensor), perspective_coeffs, interpolation=interpolation, fill=fill
)
return Video.wrap_like(self, output)
def elastic(
......@@ -178,55 +193,57 @@ class Video(_Feature):
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> Video:
output = self._F._geometry.elastic_video(self, displacement, interpolation=interpolation, fill=fill)
output = self._F.elastic_video(
self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill
)
return Video.wrap_like(self, output)
def adjust_brightness(self, brightness_factor: float) -> Video:
output = self._F.adjust_brightness_video(self, brightness_factor=brightness_factor)
output = self._F.adjust_brightness_video(self.as_subclass(torch.Tensor), brightness_factor=brightness_factor)
return Video.wrap_like(self, output)
def adjust_saturation(self, saturation_factor: float) -> Video:
output = self._F.adjust_saturation_video(self, saturation_factor=saturation_factor)
output = self._F.adjust_saturation_video(self.as_subclass(torch.Tensor), saturation_factor=saturation_factor)
return Video.wrap_like(self, output)
def adjust_contrast(self, contrast_factor: float) -> Video:
output = self._F.adjust_contrast_video(self, contrast_factor=contrast_factor)
output = self._F.adjust_contrast_video(self.as_subclass(torch.Tensor), contrast_factor=contrast_factor)
return Video.wrap_like(self, output)
def adjust_sharpness(self, sharpness_factor: float) -> Video:
output = self._F.adjust_sharpness_video(self, sharpness_factor=sharpness_factor)
output = self._F.adjust_sharpness_video(self.as_subclass(torch.Tensor), sharpness_factor=sharpness_factor)
return Video.wrap_like(self, output)
def adjust_hue(self, hue_factor: float) -> Video:
output = self._F.adjust_hue_video(self, hue_factor=hue_factor)
output = self._F.adjust_hue_video(self.as_subclass(torch.Tensor), hue_factor=hue_factor)
return Video.wrap_like(self, output)
def adjust_gamma(self, gamma: float, gain: float = 1) -> Video:
output = self._F.adjust_gamma_video(self, gamma=gamma, gain=gain)
output = self._F.adjust_gamma_video(self.as_subclass(torch.Tensor), gamma=gamma, gain=gain)
return Video.wrap_like(self, output)
def posterize(self, bits: int) -> Video:
output = self._F.posterize_video(self, bits=bits)
output = self._F.posterize_video(self.as_subclass(torch.Tensor), bits=bits)
return Video.wrap_like(self, output)
def solarize(self, threshold: float) -> Video:
output = self._F.solarize_video(self, threshold=threshold)
output = self._F.solarize_video(self.as_subclass(torch.Tensor), threshold=threshold)
return Video.wrap_like(self, output)
def autocontrast(self) -> Video:
output = self._F.autocontrast_video(self)
output = self._F.autocontrast_video(self.as_subclass(torch.Tensor))
return Video.wrap_like(self, output)
def equalize(self) -> Video:
output = self._F.equalize_video(self)
output = self._F.equalize_video(self.as_subclass(torch.Tensor))
return Video.wrap_like(self, output)
def invert(self) -> Video:
output = self._F.invert_video(self)
output = self._F.invert_video(self.as_subclass(torch.Tensor))
return Video.wrap_like(self, output)
def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Video:
output = self._F.gaussian_blur_video(self, kernel_size=kernel_size, sigma=sigma)
output = self._F.gaussian_blur_video(self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma)
return Video.wrap_like(self, output)
......
......@@ -17,7 +17,11 @@ class ConvertBoundingBoxFormat(Transform):
self.format = format
def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> features.BoundingBox:
output = F.convert_format_bounding_box(inpt, old_format=inpt.format, new_format=params["format"])
# We need to unwrap here to avoid unnecessary `__torch_function__` calls,
# since `convert_format_bounding_box` does not have a dispatcher function that would do that for us
output = F.convert_format_bounding_box(
inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=params["format"]
)
return features.BoundingBox.wrap_like(inpt, output, format=params["format"])
......@@ -31,7 +35,9 @@ class ConvertImageDtype(Transform):
def _transform(
self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
) -> Union[features.TensorImageType, features.TensorVideoType]:
output = F.convert_image_dtype(inpt, dtype=self.dtype)
# TODO: the `inpt.as_subclass(torch.Tensor)` call can be removed as soon as we have a proper dispatcher that
# handles this. See https://github.com/pytorch/vision/pull/6783 for details.
output = F.convert_image_dtype(inpt.as_subclass(torch.Tensor), dtype=self.dtype)
return (
output if features.is_simple_tensor(inpt) else type(inpt).wrap_like(inpt, output) # type: ignore[attr-defined]
)
......@@ -70,5 +76,9 @@ class ClampBoundingBoxes(Transform):
_transformed_types = (features.BoundingBox,)
def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> features.BoundingBox:
output = F.clamp_bounding_box(inpt, format=inpt.format, spatial_size=inpt.spatial_size)
# We need to unwrap here to avoid unnecessary `__torch_function__` calls,
# since `clamp_bounding_box` does not have a dispatcher function that would do that for us
output = F.clamp_bounding_box(
inpt.as_subclass(torch.Tensor), format=inpt.format, spatial_size=inpt.spatial_size
)
return features.BoundingBox.wrap_like(inpt, output)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment