Unverified Commit 4c049ca3 authored by Philip Meier, committed by GitHub

replace new_like with wrap_like (#6718)

* replace new_like with wrap_like

* fix videos

* revert casting in favor of ignoring mypy
parent 3118fb52
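
The rename is not cosmetic: `new_like` re-ran the constructor (a `torch.as_tensor` conversion plus optional `dtype`/`device`/`requires_grad` overrides), while `wrap_like` only re-attaches the feature subclass and its metadata to an already-computed tensor via `Tensor.as_subclass`. A minimal sketch of the call-site difference, assuming the prototype `features` API as of this commit:

```python
import torch
from torchvision.prototype import features

label = features.Label(torch.tensor([0, 1, 0]), categories=["foo", "bar"])
# Under the prototype __torch_function__, ops other than .clone() and .to()
# return a plain tensor, dropping the Label subclass and its metadata.
output = label * 2

# Before: label_new = features.Label.new_like(label, output, dtype=torch.int64)
# After: re-wrap the existing tensor; `categories` is copied over from `label`.
label_new = features.Label.wrap_like(label, output)

assert type(label_new) is features.Label
assert label_new.data_ptr() == output.data_ptr()  # same storage, zero copy
```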
@@ -99,14 +99,14 @@ def test_inplace_op_no_wrapping():
     assert type(label) is features.Label


-def test_new_like():
+def test_wrap_like():
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
     label = features.Label(tensor, categories=["foo", "bar"])

     # any operation besides .to() and .clone() will do here
     output = label * 2

-    label_new = features.Label.new_like(label, output)
+    label_new = features.Label.wrap_like(label, output)

     assert type(label_new) is features.Label
     assert label_new.data_ptr() == output.data_ptr()
...
@@ -8,6 +8,7 @@ import pytest
 import torch
 from common_utils import assert_equal, cpu_and_gpu
 from prototype_common_utils import (
+    DEFAULT_EXTRA_DIMS,
     make_bounding_box,
     make_bounding_boxes,
     make_detection_mask,
@@ -23,6 +24,8 @@ from torchvision.ops.boxes import box_iou
 from torchvision.prototype import features, transforms
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

+BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
+

 def make_vanilla_tensor_images(*args, **kwargs):
     for image in make_images(*args, **kwargs):
@@ -109,13 +112,11 @@ class TestSmoke:
             (
                 transform,
                 [
-                    dict(
-                        image=features.Image.new_like(image, image.unsqueeze(0), dtype=torch.float),
-                        one_hot_label=features.OneHotLabel.new_like(
-                            one_hot_label, one_hot_label.unsqueeze(0), dtype=torch.float
-                        ),
-                    )
-                    for image, one_hot_label in itertools.product(make_images(), make_one_hot_labels())
+                    dict(image=image, one_hot_label=one_hot_label)
+                    for image, one_hot_label in itertools.product(
+                        make_images(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
+                        make_one_hot_labels(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
+                    )
                 ],
             )
             for transform in [
@@ -300,7 +301,7 @@ class TestRandomHorizontalFlip:
         actual = transform(input)

         expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else input
-        expected = features.BoundingBox.new_like(input, data=expected_image_tensor)
+        expected = features.BoundingBox.wrap_like(input, expected_image_tensor)
         assert_equal(expected, actual)
         assert actual.format == expected.format
         assert actual.image_size == expected.image_size
@@ -353,7 +354,7 @@ class TestRandomVerticalFlip:
         actual = transform(input)

         expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else input
-        expected = features.BoundingBox.new_like(input, data=expected_image_tensor)
+        expected = features.BoundingBox.wrap_like(input, expected_image_tensor)
         assert_equal(expected, actual)
         assert actual.format == expected.format
         assert actual.image_size == expected.image_size
...
@@ -19,6 +19,13 @@ class BoundingBox(_Feature):
     format: BoundingBoxFormat
     image_size: Tuple[int, int]

+    @classmethod
+    def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, image_size: Tuple[int, int]) -> BoundingBox:
+        bounding_box = tensor.as_subclass(cls)
+        bounding_box.format = format
+        bounding_box.image_size = image_size
+        return bounding_box
+
     def __new__(
         cls,
         data: Any,
@@ -29,52 +36,46 @@ class BoundingBox(_Feature):
         device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: bool = False,
     ) -> BoundingBox:
-        bounding_box = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
+        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)

         if isinstance(format, str):
             format = BoundingBoxFormat.from_str(format.upper())
-        bounding_box.format = format
-        bounding_box.image_size = image_size
-        return bounding_box
+
+        return cls._wrap(tensor, format=format, image_size=image_size)
+
+    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
+        return self._make_repr(format=self.format, image_size=self.image_size)

     @classmethod
-    def new_like(
+    def wrap_like(
         cls,
         other: BoundingBox,
-        data: Any,
+        tensor: torch.Tensor,
         *,
-        format: Optional[Union[BoundingBoxFormat, str]] = None,
+        format: Optional[BoundingBoxFormat] = None,
         image_size: Optional[Tuple[int, int]] = None,
-        **kwargs: Any,
     ) -> BoundingBox:
-        return super().new_like(
-            other,
-            data,
+        return cls._wrap(
+            tensor,
             format=format if format is not None else other.format,
             image_size=image_size if image_size is not None else other.image_size,
-            **kwargs,
         )

-    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
-        return self._make_repr(format=self.format, image_size=self.image_size)
-
     def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
         if isinstance(format, str):
             format = BoundingBoxFormat.from_str(format.upper())
-        return BoundingBox.new_like(
+        return BoundingBox.wrap_like(
             self, self._F.convert_format_bounding_box(self, old_format=self.format, new_format=format), format=format
         )

     def horizontal_flip(self) -> BoundingBox:
         output = self._F.horizontal_flip_bounding_box(self, format=self.format, image_size=self.image_size)
-        return BoundingBox.new_like(self, output)
+        return BoundingBox.wrap_like(self, output)

     def vertical_flip(self) -> BoundingBox:
         output = self._F.vertical_flip_bounding_box(self, format=self.format, image_size=self.image_size)
-        return BoundingBox.new_like(self, output)
+        return BoundingBox.wrap_like(self, output)

     def resize(  # type: ignore[override]
         self,
@@ -84,19 +85,19 @@ class BoundingBox(_Feature):
         antialias: bool = False,
     ) -> BoundingBox:
         output, image_size = self._F.resize_bounding_box(self, image_size=self.image_size, size=size, max_size=max_size)
-        return BoundingBox.new_like(self, output, image_size=image_size)
+        return BoundingBox.wrap_like(self, output, image_size=image_size)

     def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox:
         output, image_size = self._F.crop_bounding_box(
             self, self.format, top=top, left=left, height=height, width=width
         )
-        return BoundingBox.new_like(self, output, image_size=image_size)
+        return BoundingBox.wrap_like(self, output, image_size=image_size)

     def center_crop(self, output_size: List[int]) -> BoundingBox:
         output, image_size = self._F.center_crop_bounding_box(
             self, format=self.format, image_size=self.image_size, output_size=output_size
         )
-        return BoundingBox.new_like(self, output, image_size=image_size)
+        return BoundingBox.wrap_like(self, output, image_size=image_size)

     def resized_crop(
         self,
@@ -109,7 +110,7 @@ class BoundingBox(_Feature):
         antialias: bool = False,
     ) -> BoundingBox:
         output, image_size = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size)
-        return BoundingBox.new_like(self, output, image_size=image_size)
+        return BoundingBox.wrap_like(self, output, image_size=image_size)

     def pad(
         self,
@@ -120,7 +121,7 @@ class BoundingBox(_Feature):
         output, image_size = self._F.pad_bounding_box(
             self, format=self.format, image_size=self.image_size, padding=padding, padding_mode=padding_mode
         )
-        return BoundingBox.new_like(self, output, image_size=image_size)
+        return BoundingBox.wrap_like(self, output, image_size=image_size)

     def rotate(
         self,
@@ -133,7 +134,7 @@ class BoundingBox(_Feature):
         output, image_size = self._F.rotate_bounding_box(
             self, format=self.format, image_size=self.image_size, angle=angle, expand=expand, center=center
         )
-        return BoundingBox.new_like(self, output, image_size=image_size)
+        return BoundingBox.wrap_like(self, output, image_size=image_size)

     def affine(
         self,
@@ -155,7 +156,7 @@ class BoundingBox(_Feature):
             shear=shear,
             center=center,
         )
-        return BoundingBox.new_like(self, output, dtype=output.dtype)
+        return BoundingBox.wrap_like(self, output)

     def perspective(
         self,
@@ -164,7 +165,7 @@ class BoundingBox(_Feature):
         fill: FillTypeJIT = None,
     ) -> BoundingBox:
         output = self._F.perspective_bounding_box(self, self.format, perspective_coeffs)
-        return BoundingBox.new_like(self, output, dtype=output.dtype)
+        return BoundingBox.wrap_like(self, output)

     def elastic(
         self,
@@ -173,4 +174,4 @@ class BoundingBox(_Feature):
         fill: FillTypeJIT = None,
     ) -> BoundingBox:
         output = self._F.elastic_bounding_box(self, self.format, displacement)
-        return BoundingBox.new_like(self, output, dtype=output.dtype)
+        return BoundingBox.wrap_like(self, output)
@@ -14,6 +14,10 @@ D = TypeVar("D", bound="EncodedData")

 class EncodedData(_Feature):
+    @classmethod
+    def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
+        return tensor.as_subclass(cls)
+
     def __new__(
         cls,
         data: Any,
@@ -22,8 +26,13 @@ class EncodedData(_Feature):
         device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: bool = False,
     ) -> EncodedData:
+        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
         # TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8?
-        return super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
+        return cls._wrap(tensor)
+
+    @classmethod
+    def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
+        return cls._wrap(tensor)

     @classmethod
     def from_file(cls: Type[D], file: BinaryIO, **kwargs: Any) -> D:
...
@@ -21,48 +21,39 @@ def is_simple_tensor(inpt: Any) -> bool:
 class _Feature(torch.Tensor):
     __F: Optional[ModuleType] = None

-    def __new__(
-        cls: Type[F],
+    @staticmethod
+    def _to_tensor(
         data: Any,
-        *,
         dtype: Optional[torch.dtype] = None,
         device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: bool = False,
-    ) -> F:
-        return (
-            torch.as_tensor(  # type: ignore[return-value]
-                data,
-                dtype=dtype,
-                device=device,
-            )
-            .as_subclass(cls)
-            .requires_grad_(requires_grad)
-        )
+    ) -> torch.Tensor:
+        return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)

-    @classmethod
-    def new_like(
-        cls: Type[F],
-        other: F,
+    # FIXME: this is just here for BC with the prototype datasets. Some datasets use the _Feature directly to have
+    # a no-op input for the prototype transforms. For this use case, we can't use plain tensors, since they will be
+    # interpreted as images. We should decide if we want a public no-op feature like `GenericFeature` or make this one
+    # public again.
+    def __new__(
+        cls,
         data: Any,
-        *,
         dtype: Optional[torch.dtype] = None,
         device: Optional[Union[torch.device, str, int]] = None,
-        requires_grad: Optional[bool] = None,
-        **kwargs: Any,
-    ) -> F:
-        return cls(
-            data,
-            dtype=dtype if dtype is not None else other.dtype,
-            device=device if device is not None else other.device,
-            requires_grad=requires_grad if requires_grad is not None else other.requires_grad,
-            **kwargs,
-        )
+        requires_grad: bool = False,
+    ) -> _Feature:
+        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
+        return tensor.as_subclass(_Feature)
+
+    @classmethod
+    def wrap_like(cls: Type[F], other: F, tensor: torch.Tensor) -> F:
+        # FIXME: this is just here for BC with the prototype datasets. See __new__ for details. If that is resolved,
+        # this method should be made abstract
+        # raise NotImplementedError
+        return tensor.as_subclass(cls)

     _NO_WRAPPING_EXCEPTIONS = {
-        torch.Tensor.clone: lambda cls, input, output: cls.new_like(input, output),
-        torch.Tensor.to: lambda cls, input, output: cls.new_like(
-            input, output, dtype=output.dtype, device=output.device
-        ),
+        torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
+        torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output),
         # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
         # retains the type automatically
         torch.Tensor.requires_grad_: lambda cls, input, output: output,
...
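
Every feature class now follows the same three-method shape: `__new__` funnels construction through `_to_tensor`, a private `_wrap` attaches metadata to a subclass view of the tensor, and `wrap_like` calls `_wrap` with metadata defaulted from an existing instance. A self-contained sketch of that pattern; `MyLabel` is a hypothetical stand-in, not part of torchvision:

```python
from __future__ import annotations

from typing import Any, Optional, Sequence

import torch


class MyLabel(torch.Tensor):
    categories: Optional[Sequence[str]]

    @classmethod
    def _wrap(cls, tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> MyLabel:
        # Re-attach the subclass and its metadata to a view of the same storage.
        label = tensor.as_subclass(cls)
        label.categories = categories
        return label

    def __new__(cls, data: Any, *, categories: Optional[Sequence[str]] = None) -> MyLabel:
        # Construction converts the input once via torch.as_tensor ...
        return cls._wrap(torch.as_tensor(data), categories=categories)

    @classmethod
    def wrap_like(cls, other: MyLabel, tensor: torch.Tensor) -> MyLabel:
        # ... while re-wrapping skips the conversion and inherits other's metadata.
        return cls._wrap(tensor, categories=other.categories)


label = MyLabel([0, 1, 0], categories=["foo", "bar"])
doubled = MyLabel.wrap_like(label, label * 2)  # metadata survives the re-wrap
assert doubled.categories == ["foo", "bar"]
```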
@@ -62,6 +62,12 @@ def _from_tensor_shape(shape: List[int]) -> ColorSpace:
 class Image(_Feature):
     color_space: ColorSpace

+    @classmethod
+    def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Image:
+        image = tensor.as_subclass(cls)
+        image.color_space = color_space
+        return image
+
     def __new__(
         cls,
         data: Any,
@@ -71,36 +77,33 @@ class Image(_Feature):
         device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: bool = False,
     ) -> Image:
-        data = torch.as_tensor(data, dtype=dtype, device=device)
-        if data.ndim < 2:
+        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
+        if tensor.ndim < 2:
             raise ValueError
-        elif data.ndim == 2:
-            data = data.unsqueeze(0)
-        image = super().__new__(cls, data, requires_grad=requires_grad)
+        elif tensor.ndim == 2:
+            tensor = tensor.unsqueeze(0)

         if color_space is None:
-            color_space = ColorSpace.from_tensor_shape(image.shape)  # type: ignore[arg-type]
+            color_space = ColorSpace.from_tensor_shape(tensor.shape)  # type: ignore[arg-type]
             if color_space == ColorSpace.OTHER:
                 warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
         elif isinstance(color_space, str):
             color_space = ColorSpace.from_str(color_space.upper())
         elif not isinstance(color_space, ColorSpace):
             raise ValueError
-        image.color_space = color_space

-        return image
+        return cls._wrap(tensor, color_space=color_space)
+
+    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
+        return self._make_repr(color_space=self.color_space)

     @classmethod
-    def new_like(
-        cls, other: Image, data: Any, *, color_space: Optional[Union[ColorSpace, str]] = None, **kwargs: Any
-    ) -> Image:
-        return super().new_like(
-            other, data, color_space=color_space if color_space is not None else other.color_space, **kwargs
+    def wrap_like(cls, other: Image, tensor: torch.Tensor, *, color_space: Optional[ColorSpace] = None) -> Image:
+        return cls._wrap(
+            tensor,
+            color_space=color_space if color_space is not None else other.color_space,
         )

-    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
-        return self._make_repr(color_space=self.color_space)
-
     @property
     def image_size(self) -> Tuple[int, int]:
         return cast(Tuple[int, int], tuple(self.shape[-2:]))
@@ -113,7 +116,7 @@ class Image(_Feature):
         if isinstance(color_space, str):
             color_space = ColorSpace.from_str(color_space.upper())

-        return Image.new_like(
+        return Image.wrap_like(
             self,
             self._F.convert_color_space_image_tensor(
                 self, old_color_space=self.color_space, new_color_space=color_space, copy=copy
@@ -129,15 +132,15 @@ class Image(_Feature):
     def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> Image:
         # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
         # promote this out of the prototype state
-        return Image.new_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs))
+        return Image.wrap_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs))

     def horizontal_flip(self) -> Image:
         output = self._F.horizontal_flip_image_tensor(self)
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)

     def vertical_flip(self) -> Image:
         output = self._F.vertical_flip_image_tensor(self)
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)

     def resize(  # type: ignore[override]
         self,
@@ -149,15 +152,15 @@ class Image(_Feature):
         output = self._F.resize_image_tensor(
             self, size, interpolation=interpolation, max_size=max_size, antialias=antialias
         )
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)

     def crop(self, top: int, left: int, height: int, width: int) -> Image:
         output = self._F.crop_image_tensor(self, top, left, height, width)
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)

     def center_crop(self, output_size: List[int]) -> Image:
         output = self._F.center_crop_image_tensor(self, output_size=output_size)
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)

     def resized_crop(
         self,
@@ -172,7 +175,7 @@ class Image(_Feature):
         output = self._F.resized_crop_image_tensor(
             self, top, left, height, width, size=list(size), interpolation=interpolation, antialias=antialias
         )
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)

     def pad(
         self,
@@ -181,7 +184,7 @@ class Image(_Feature):
         padding_mode: str = "constant",
     ) -> Image:
         output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)

     def rotate(
         self,
@@ -194,7 +197,7 @@ class Image(_Feature):
         output = self._F._geometry.rotate_image_tensor(
             self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center
         )
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)

     def affine(
         self,
@@ -216,7 +219,7 @@ class Image(_Feature):
             fill=fill,
             center=center,
         )
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)

     def perspective(
         self,
@@ -227,7 +230,7 @@ class Image(_Feature):
         output = self._F._geometry.perspective_image_tensor(
             self, perspective_coeffs, interpolation=interpolation, fill=fill
         )
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)

     def elastic(
         self,
@@ -236,55 +239,55 @@ class Image(_Feature):
         fill: FillTypeJIT = None,
     ) -> Image:
         output = self._F._geometry.elastic_image_tensor(self, displacement, interpolation=interpolation, fill=fill)
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)

     def adjust_brightness(self, brightness_factor: float) -> Image:
         output = self._F.adjust_brightness_image_tensor(self, brightness_factor=brightness_factor)
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)

     def adjust_saturation(self, saturation_factor: float) -> Image:
         output = self._F.adjust_saturation_image_tensor(self, saturation_factor=saturation_factor)
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)

     def adjust_contrast(self, contrast_factor: float) -> Image:
         output = self._F.adjust_contrast_image_tensor(self, contrast_factor=contrast_factor)
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)

     def adjust_sharpness(self, sharpness_factor: float) -> Image:
         output = self._F.adjust_sharpness_image_tensor(self, sharpness_factor=sharpness_factor)
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)

     def adjust_hue(self, hue_factor: float) -> Image:
         output = self._F.adjust_hue_image_tensor(self, hue_factor=hue_factor)
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)

     def adjust_gamma(self, gamma: float, gain: float = 1) -> Image:
         output = self._F.adjust_gamma_image_tensor(self, gamma=gamma, gain=gain)
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)

     def posterize(self, bits: int) -> Image:
         output = self._F.posterize_image_tensor(self, bits=bits)
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)

     def solarize(self, threshold: float) -> Image:
         output = self._F.solarize_image_tensor(self, threshold=threshold)
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)

     def autocontrast(self) -> Image:
         output = self._F.autocontrast_image_tensor(self)
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)

     def equalize(self) -> Image:
         output = self._F.equalize_image_tensor(self)
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)

     def invert(self) -> Image:
         output = self._F.invert_image_tensor(self)
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)

     def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image:
         output = self._F.gaussian_blur_image_tensor(self, kernel_size=kernel_size, sigma=sigma)
-        return Image.new_like(self, output)
+        return Image.wrap_like(self, output)


 ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
...
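
The keyword arguments on `wrap_like` exist so that kernels which change metadata can override what is otherwise inherited from `other`. A hypothetical usage sketch, assuming the prototype API of this commit (the mean over the channel dim is a placeholder, not the real `rgb_to_grayscale` kernel):

```python
import torch
from torchvision.prototype import features

img = features.Image(torch.rand(3, 32, 32), color_space=features.ColorSpace.RGB)
gray_data = img.mean(dim=-3, keepdim=True)  # placeholder kernel, plain tensor out

# Without the override, wrap_like would copy color_space=RGB from `img`.
gray = features.Image.wrap_like(img, gray_data, color_space=features.ColorSpace.GRAY)
assert gray.color_space is features.ColorSpace.GRAY
```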
@@ -14,6 +14,12 @@ L = TypeVar("L", bound="_LabelBase")
 class _LabelBase(_Feature):
     categories: Optional[Sequence[str]]

+    @classmethod
+    def _wrap(cls: Type[L], tensor: torch.Tensor, *, categories: Optional[Sequence[str]]) -> L:
+        label_base = tensor.as_subclass(cls)
+        label_base.categories = categories
+        return label_base
+
     def __new__(
         cls: Type[L],
         data: Any,
@@ -23,16 +29,14 @@ class _LabelBase(_Feature):
         device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: bool = False,
     ) -> L:
-        label_base = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
-
-        label_base.categories = categories
-
-        return label_base
+        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
+        return cls._wrap(tensor, categories=categories)

     @classmethod
-    def new_like(cls: Type[L], other: L, data: Any, *, categories: Optional[Sequence[str]] = None, **kwargs: Any) -> L:
-        return super().new_like(
-            other, data, categories=categories if categories is not None else other.categories, **kwargs
+    def wrap_like(cls: Type[L], other: L, tensor: torch.Tensor, *, categories: Optional[Sequence[str]] = None) -> L:
+        return cls._wrap(
+            tensor,
+            categories=categories if categories is not None else other.categories,
         )

     @classmethod
...
 from __future__ import annotations

-from typing import List, Optional, Union
+from typing import Any, List, Optional, Union

 import torch
 from torchvision.transforms import InterpolationMode
@@ -9,13 +9,36 @@ from ._feature import _Feature, FillTypeJIT

 class Mask(_Feature):
+    @classmethod
+    def _wrap(cls, tensor: torch.Tensor) -> Mask:
+        return tensor.as_subclass(cls)
+
+    def __new__(
+        cls,
+        data: Any,
+        *,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[Union[torch.device, str, int]] = None,
+        requires_grad: bool = False,
+    ) -> Mask:
+        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
+        return cls._wrap(tensor)
+
+    @classmethod
+    def wrap_like(
+        cls,
+        other: Mask,
+        tensor: torch.Tensor,
+    ) -> Mask:
+        return cls._wrap(tensor)
+
     def horizontal_flip(self) -> Mask:
         output = self._F.horizontal_flip_mask(self)
-        return Mask.new_like(self, output)
+        return Mask.wrap_like(self, output)

     def vertical_flip(self) -> Mask:
         output = self._F.vertical_flip_mask(self)
-        return Mask.new_like(self, output)
+        return Mask.wrap_like(self, output)

     def resize(  # type: ignore[override]
         self,
@@ -25,15 +48,15 @@ class Mask(_Feature):
         antialias: bool = False,
     ) -> Mask:
         output = self._F.resize_mask(self, size, max_size=max_size)
-        return Mask.new_like(self, output)
+        return Mask.wrap_like(self, output)

     def crop(self, top: int, left: int, height: int, width: int) -> Mask:
         output = self._F.crop_mask(self, top, left, height, width)
-        return Mask.new_like(self, output)
+        return Mask.wrap_like(self, output)

     def center_crop(self, output_size: List[int]) -> Mask:
         output = self._F.center_crop_mask(self, output_size=output_size)
-        return Mask.new_like(self, output)
+        return Mask.wrap_like(self, output)

     def resized_crop(
         self,
@@ -46,7 +69,7 @@ class Mask(_Feature):
         antialias: bool = False,
     ) -> Mask:
         output = self._F.resized_crop_mask(self, top, left, height, width, size=size)
-        return Mask.new_like(self, output)
+        return Mask.wrap_like(self, output)

     def pad(
         self,
@@ -55,7 +78,7 @@ class Mask(_Feature):
         padding_mode: str = "constant",
     ) -> Mask:
         output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill)
-        return Mask.new_like(self, output)
+        return Mask.wrap_like(self, output)

     def rotate(
         self,
@@ -66,7 +89,7 @@ class Mask(_Feature):
         center: Optional[List[float]] = None,
     ) -> Mask:
         output = self._F.rotate_mask(self, angle, expand=expand, center=center, fill=fill)
-        return Mask.new_like(self, output)
+        return Mask.wrap_like(self, output)

     def affine(
         self,
@@ -87,7 +110,7 @@ class Mask(_Feature):
             fill=fill,
             center=center,
         )
-        return Mask.new_like(self, output)
+        return Mask.wrap_like(self, output)

     def perspective(
         self,
@@ -96,7 +119,7 @@ class Mask(_Feature):
         fill: FillTypeJIT = None,
     ) -> Mask:
         output = self._F.perspective_mask(self, perspective_coeffs, fill=fill)
-        return Mask.new_like(self, output)
+        return Mask.wrap_like(self, output)

     def elastic(
         self,
@@ -105,4 +128,4 @@ class Mask(_Feature):
         fill: FillTypeJIT = None,
     ) -> Mask:
         output = self._F.elastic_mask(self, displacement, fill=fill)
-        return Mask.new_like(self, output, dtype=output.dtype)
+        return Mask.wrap_like(self, output)
@@ -13,6 +13,12 @@ from ._image import ColorSpace, ImageType, ImageTypeJIT, TensorImageType, Tensor
 class Video(_Feature):
     color_space: ColorSpace

+    @classmethod
+    def _wrap(cls, tensor: torch.Tensor, *, color_space: ColorSpace) -> Video:
+        video = tensor.as_subclass(cls)
+        video.color_space = color_space
+        return video
+
     def __new__(
         cls,
         data: Any,
@@ -22,7 +28,7 @@ class Video(_Feature):
         device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: bool = False,
     ) -> Video:
-        data = torch.as_tensor(data, dtype=dtype, device=device)
+        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
         if data.ndim < 4:
             raise ValueError
         video = super().__new__(cls, data, requires_grad=requires_grad)
@@ -35,21 +41,19 @@ class Video(_Feature):
             color_space = ColorSpace.from_str(color_space.upper())
         elif not isinstance(color_space, ColorSpace):
             raise ValueError
-        video.color_space = color_space

-        return video
+        return cls._wrap(tensor, color_space=color_space)

-    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
-        return self._make_repr(color_space=self.color_space)
-
     @classmethod
-    def new_like(
-        cls, other: Video, data: Any, *, color_space: Optional[Union[ColorSpace, str]] = None, **kwargs: Any
-    ) -> Video:
-        return super().new_like(
-            other, data, color_space=color_space if color_space is not None else other.color_space, **kwargs
+    def wrap_like(cls, other: Video, tensor: torch.Tensor, *, color_space: Optional[ColorSpace] = None) -> Video:
+        return cls._wrap(
+            tensor,
+            color_space=color_space if color_space is not None else other.color_space,
         )

+    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
+        return self._make_repr(color_space=self.color_space)
+
     # TODO: rename this (and all instances of this term to spatial size)
     @property
     def image_size(self) -> Tuple[int, int]:
@@ -67,7 +71,7 @@ class Video(_Feature):
         if isinstance(color_space, str):
             color_space = ColorSpace.from_str(color_space.upper())

-        return Video.new_like(
+        return Video.wrap_like(
             self,
             self._F.convert_color_space_video(
                 self, old_color_space=self.color_space, new_color_space=color_space, copy=copy
@@ -77,11 +81,11 @@ class Video(_Feature):

     def horizontal_flip(self) -> Video:
         output = self._F.horizontal_flip_video(self)
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)

     def vertical_flip(self) -> Video:
         output = self._F.vertical_flip_video(self)
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)

     def resize(  # type: ignore[override]
         self,
@@ -91,15 +95,15 @@ class Video(_Feature):
         antialias: bool = False,
     ) -> Video:
         output = self._F.resize_video(self, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)

     def crop(self, top: int, left: int, height: int, width: int) -> Video:
         output = self._F.crop_video(self, top, left, height, width)
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)

     def center_crop(self, output_size: List[int]) -> Video:
         output = self._F.center_crop_video(self, output_size=output_size)
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)

     def resized_crop(
         self,
@@ -114,7 +118,7 @@ class Video(_Feature):
         output = self._F.resized_crop_video(
             self, top, left, height, width, size=list(size), interpolation=interpolation, antialias=antialias
         )
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)

     def pad(
         self,
@@ -123,7 +127,7 @@ class Video(_Feature):
         padding_mode: str = "constant",
     ) -> Video:
         output = self._F.pad_video(self, padding, fill=fill, padding_mode=padding_mode)
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)

     def rotate(
         self,
@@ -136,7 +140,7 @@ class Video(_Feature):
         output = self._F._geometry.rotate_video(
             self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center
         )
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)

     def affine(
         self,
@@ -158,7 +162,7 @@ class Video(_Feature):
             fill=fill,
             center=center,
         )
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)

     def perspective(
         self,
@@ -167,7 +171,7 @@ class Video(_Feature):
         fill: FillTypeJIT = None,
     ) -> Video:
         output = self._F._geometry.perspective_video(self, perspective_coeffs, interpolation=interpolation, fill=fill)
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)

     def elastic(
         self,
@@ -176,55 +180,55 @@ class Video(_Feature):
         fill: FillTypeJIT = None,
     ) -> Video:
         output = self._F._geometry.elastic_video(self, displacement, interpolation=interpolation, fill=fill)
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)

     def adjust_brightness(self, brightness_factor: float) -> Video:
         output = self._F.adjust_brightness_video(self, brightness_factor=brightness_factor)
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)

     def adjust_saturation(self, saturation_factor: float) -> Video:
         output = self._F.adjust_saturation_video(self, saturation_factor=saturation_factor)
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)

     def adjust_contrast(self, contrast_factor: float) -> Video:
         output = self._F.adjust_contrast_video(self, contrast_factor=contrast_factor)
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)

     def adjust_sharpness(self, sharpness_factor: float) -> Video:
         output = self._F.adjust_sharpness_video(self, sharpness_factor=sharpness_factor)
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)

     def adjust_hue(self, hue_factor: float) -> Video:
         output = self._F.adjust_hue_video(self, hue_factor=hue_factor)
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)

     def adjust_gamma(self, gamma: float, gain: float = 1) -> Video:
         output = self._F.adjust_gamma_video(self, gamma=gamma, gain=gain)
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)

     def posterize(self, bits: int) -> Video:
         output = self._F.posterize_video(self, bits=bits)
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)

     def solarize(self, threshold: float) -> Video:
         output = self._F.solarize_video(self, threshold=threshold)
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)

     def autocontrast(self) -> Video:
         output = self._F.autocontrast_video(self)
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)

     def equalize(self) -> Video:
         output = self._F.equalize_video(self)
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)

     def invert(self) -> Video:
         output = self._F.invert_video(self)
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)

     def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Video:
         output = self._F.gaussian_blur_video(self, kernel_size=kernel_size, sigma=sigma)
-        return Video.new_like(self, output)
+        return Video.wrap_like(self, output)


 VideoType = Union[torch.Tensor, Video]
...
@@ -119,7 +119,7 @@ class _BaseMixupCutmix(_RandomApplyTransform):
             raise ValueError("Need a batch of one hot labels")
         output = inpt.clone()
         output = output.roll(1, -2).mul_(1 - lam).add_(output.mul_(lam))
-        return features.OneHotLabel.new_like(inpt, output)
+        return features.OneHotLabel.wrap_like(inpt, output)


 class RandomMixup(_BaseMixupCutmix):
@@ -135,7 +135,7 @@ class RandomMixup(_BaseMixupCutmix):
             output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam))

             if isinstance(inpt, features.Image):
-                output = features.Image.new_like(inpt, output)
+                output = features.Image.wrap_like(inpt, output)

             return output
         elif isinstance(inpt, features.OneHotLabel):
@@ -178,7 +178,7 @@ class RandomCutmix(_BaseMixupCutmix):
             output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]

             if isinstance(inpt, features.Image):
-                output = features.Image.new_like(inpt, output)
+                output = features.Image.wrap_like(inpt, output)

             return output
         elif isinstance(inpt, features.OneHotLabel):
@@ -213,9 +213,11 @@ class SimpleCopyPaste(_RandomApplyTransform):
         antialias: Optional[bool],
     ) -> Tuple[features.TensorImageType, Dict[str, Any]]:
-        paste_masks = paste_target["masks"].new_like(paste_target["masks"], paste_target["masks"][random_selection])
-        paste_boxes = paste_target["boxes"].new_like(paste_target["boxes"], paste_target["boxes"][random_selection])
-        paste_labels = paste_target["labels"].new_like(paste_target["labels"], paste_target["labels"][random_selection])
+        paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection])
+        paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection])
+        paste_labels = paste_target["labels"].wrap_like(
+            paste_target["labels"], paste_target["labels"][random_selection]
+        )

         masks = target["masks"]
@@ -317,7 +319,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
         c0, c1, c2, c3 = 0, 0, 0, 0
         for i, obj in enumerate(flat_sample):
             if isinstance(obj, features.Image):
-                flat_sample[i] = features.Image.new_like(obj, output_images[c0])
+                flat_sample[i] = features.Image.wrap_like(obj, output_images[c0])
                 c0 += 1
             elif isinstance(obj, PIL.Image.Image):
                 flat_sample[i] = F.to_image_pil(output_images[c0])
@@ -326,13 +328,13 @@ class SimpleCopyPaste(_RandomApplyTransform):
                 flat_sample[i] = output_images[c0]
                 c0 += 1
             elif isinstance(obj, features.BoundingBox):
-                flat_sample[i] = features.BoundingBox.new_like(obj, output_targets[c1]["boxes"])
+                flat_sample[i] = features.BoundingBox.wrap_like(obj, output_targets[c1]["boxes"])
                 c1 += 1
             elif isinstance(obj, features.Mask):
-                flat_sample[i] = features.Mask.new_like(obj, output_targets[c2]["masks"])
+                flat_sample[i] = features.Mask.wrap_like(obj, output_targets[c2]["masks"])
                 c2 += 1
             elif isinstance(obj, (features.Label, features.OneHotLabel)):
-                flat_sample[i] = obj.new_like(obj, output_targets[c3]["labels"])  # type: ignore[arg-type]
+                flat_sample[i] = obj.wrap_like(obj, output_targets[c3]["labels"])  # type: ignore[arg-type]
                 c3 += 1

     def forward(self, *inputs: Any) -> Any:
...
@@ -520,7 +520,7 @@ class AugMix(_AutoAugmentBase):
             mix = mix.view(orig_dims).to(dtype=image_or_video.dtype)

             if isinstance(orig_image_or_video, (features.Image, features.Video)):
-                mix = type(orig_image_or_video).new_like(orig_image_or_video, mix)  # type: ignore[arg-type]
+                mix = type(orig_image_or_video).wrap_like(orig_image_or_video, mix)  # type: ignore[arg-type]
             elif isinstance(orig_image_or_video, PIL.Image.Image):
                 mix = F.to_image_pil(mix)
...
@@ -119,7 +119,8 @@ class RandomPhotometricDistort(Transform):
         output = inpt[..., permutation, :, :]

         if isinstance(inpt, (features.Image, features.Video)):
-            output = type(inpt).new_like(inpt, output, color_space=features.ColorSpace.OTHER)  # type: ignore[arg-type]
+            output = type(inpt).wrap_like(inpt, output, color_space=features.ColorSpace.OTHER)  # type: ignore[arg-type]
         elif isinstance(inpt, PIL.Image.Image):
             output = F.to_image_pil(output)
...
@@ -55,7 +55,7 @@ class Grayscale(Transform):
     def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
         output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
         if isinstance(inpt, features.Image):
-            output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.GRAY)
+            output = features.Image.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY)
         return output
@@ -84,5 +84,5 @@ class RandomGrayscale(_RandomApplyTransform):
     def _transform(self, inpt: features.ImageType, params: Dict[str, Any]) -> features.ImageType:
         output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
         if isinstance(inpt, features.Image):
-            output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.GRAY)
+            output = features.Image.wrap_like(inpt, output, color_space=features.ColorSpace.GRAY)
         return output
@@ -158,8 +158,8 @@ class FiveCrop(Transform):
         ...     def forward(self, sample: Tuple[Tuple[features.Image, ...], features.Label]):
         ...         images, labels = sample
         ...         batch_size = len(images)
-        ...         images = features.Image.new_like(images[0], torch.stack(images))
-        ...         labels = features.Label.new_like(labels, labels.repeat(batch_size))
+        ...         images = features.Image.wrap_like(images[0], torch.stack(images))
+        ...         labels = features.Label.wrap_like(labels, labels.repeat(batch_size))
         ...         return images, labels
         ...
         >>> image = features.Image(torch.rand(3, 256, 256))
@@ -677,18 +677,18 @@ class RandomIoUCrop(Transform):
         is_within_crop_area = params["is_within_crop_area"]

         if isinstance(inpt, (features.Label, features.OneHotLabel)):
-            return inpt.new_like(inpt, inpt[is_within_crop_area])  # type: ignore[arg-type]
+            return inpt.wrap_like(inpt, inpt[is_within_crop_area])  # type: ignore[arg-type]

         output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])

         if isinstance(output, features.BoundingBox):
             bboxes = output[is_within_crop_area]
             bboxes = F.clamp_bounding_box(bboxes, output.format, output.image_size)
-            output = features.BoundingBox.new_like(output, bboxes)
+            output = features.BoundingBox.wrap_like(output, bboxes)
         elif isinstance(output, features.Mask):
             # apply is_within_crop_area if mask is one-hot encoded
             masks = output[is_within_crop_area]
-            output = features.Mask.new_like(output, masks)
+            output = features.Mask.wrap_like(output, masks)

         return output
@@ -801,7 +801,7 @@ class FixedSizeCrop(Transform):
             bounding_boxes = cast(
                 features.BoundingBox, F.crop(bounding_boxes, top=top, left=left, height=new_height, width=new_width)
             )
-            bounding_boxes = features.BoundingBox.new_like(
+            bounding_boxes = features.BoundingBox.wrap_like(
                 bounding_boxes,
                 F.clamp_bounding_box(
                     bounding_boxes, format=bounding_boxes.format, image_size=bounding_boxes.image_size
@@ -840,9 +840,9 @@ class FixedSizeCrop(Transform):
         if params["is_valid"] is not None:
             if isinstance(inpt, (features.Label, features.OneHotLabel, features.Mask)):
-                inpt = inpt.new_like(inpt, inpt[params["is_valid"]])  # type: ignore[arg-type]
+                inpt = inpt.wrap_like(inpt, inpt[params["is_valid"]])  # type: ignore[arg-type]
             elif isinstance(inpt, features.BoundingBox):
-                inpt = features.BoundingBox.new_like(
+                inpt = features.BoundingBox.wrap_like(
                     inpt,
                     F.clamp_bounding_box(inpt[params["is_valid"]], format=inpt.format, image_size=inpt.image_size),
                 )
...
@@ -18,7 +18,7 @@ class ConvertBoundingBoxFormat(Transform):
     def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> features.BoundingBox:
         output = F.convert_format_bounding_box(inpt, old_format=inpt.format, new_format=params["format"])
-        return features.BoundingBox.new_like(inpt, output, format=params["format"])
+        return features.BoundingBox.wrap_like(inpt, output, format=params["format"])


 class ConvertImageDtype(Transform):
@@ -30,7 +30,11 @@ class ConvertImageDtype(Transform):
     def _transform(self, inpt: features.TensorImageType, params: Dict[str, Any]) -> features.TensorImageType:
         output = F.convert_image_dtype(inpt, dtype=self.dtype)
-        return output if features.is_simple_tensor(inpt) else features.Image.new_like(inpt, output, dtype=self.dtype)  # type: ignore[arg-type]
+        return (
+            output
+            if features.is_simple_tensor(inpt)
+            else features.Image.wrap_like(inpt, output)  # type: ignore[arg-type]
+        )


 class ConvertColorSpace(Transform):
@@ -65,4 +69,4 @@ class ClampBoundingBoxes(Transform):
     def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> features.BoundingBox:
         output = F.clamp_bounding_box(inpt, format=inpt.format, image_size=inpt.image_size)
-        return features.BoundingBox.new_like(inpt, output)
+        return features.BoundingBox.wrap_like(inpt, output)
@@ -171,4 +171,4 @@ class RemoveSmallBoundingBoxes(Transform):
         return dict(valid_indices=valid_indices)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return inpt.new_like(inpt, inpt[params["valid_indices"]])
+        return inpt.wrap_like(inpt, inpt[params["valid_indices"]])
@@ -35,7 +35,7 @@ def erase(
     if isinstance(inpt, torch.Tensor):
        output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
        if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
-            output = type(inpt).new_like(inpt, output)  # type: ignore[arg-type]
+            output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]
         return output
     else:  # isinstance(inpt, PIL.Image.Image):
         return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
@@ -1409,7 +1409,7 @@ def five_crop(
     if isinstance(inpt, torch.Tensor):
         output = five_crop_image_tensor(inpt, size)
         if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
-            output = tuple(features.Image.new_like(inpt, item) for item in output)  # type: ignore[assignment]
+            output = tuple(features.Image.wrap_like(inpt, item) for item in output)  # type: ignore[assignment]
         return output
     else:  # isinstance(inpt, PIL.Image.Image):
         return five_crop_image_pil(inpt, size)
@@ -1446,7 +1446,7 @@ def ten_crop(inpt: features.ImageTypeJIT, size: List[int], vertical_flip: bool =
     if isinstance(inpt, torch.Tensor):
         output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
         if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
-            output = [features.Image.new_like(inpt, item) for item in output]
+            output = [features.Image.wrap_like(inpt, item) for item in output]
         return output
     else:  # isinstance(inpt, PIL.Image.Image):
         return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
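
All of the functional ops above share the same guard: the kernel computes on the unwrapped tensor, and the feature type is only restored outside of TorchScript, which cannot represent `Tensor` subclasses. A simplified sketch of that shape; the flip body is a stand-in for the real dispatch:

```python
import torch
from torchvision.prototype import features


def horizontal_flip_sketch(inpt: torch.Tensor) -> torch.Tensor:
    output = inpt.flip(-1)  # stand-in for the dispatched kernel
    # Re-wrap only in eager mode; under scripting the plain tensor is returned.
    if not torch.jit.is_scripting() and isinstance(inpt, features.Image):
        output = features.Image.wrap_like(inpt, output)
    return output
```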