Unverified commit f88ab124, authored by Vasilis Vryniotis, committed by GitHub
Browse files

Add PermuteDimensions and TransposeDimensions transforms (#6800)

* Add PermuteDimensions and TransposeDimensions transforms

* Strip Subclass info.

* Apply changes from code review.
parent 37618552
...@@ -18,10 +18,12 @@ from prototype_common_utils import ( ...@@ -18,10 +18,12 @@ from prototype_common_utils import (
make_masks, make_masks,
make_one_hot_labels, make_one_hot_labels,
make_segmentation_mask, make_segmentation_mask,
make_video,
make_videos, make_videos,
) )
from torchvision.ops.boxes import box_iou from torchvision.ops.boxes import box_iou
from torchvision.prototype import features, transforms from torchvision.prototype import features, transforms
from torchvision.prototype.transforms._utils import _isinstance
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims] BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
...@@ -1826,3 +1828,74 @@ def test_to_dtype(dtype, expected_dtypes): ...@@ -1826,3 +1828,74 @@ def test_to_dtype(dtype, expected_dtypes):
assert transformed_value.dtype is expected_dtypes[value_type] assert transformed_value.dtype is expected_dtypes[value_type]
else: else:
assert transformed_value is value assert transformed_value is value
@pytest.mark.parametrize(
    ("dims", "inverse_dims"),
    [
        (
            {torch.Tensor: (1, 2, 0), features.Image: (2, 1, 0), features.Video: None},
            {torch.Tensor: (2, 0, 1), features.Image: (2, 1, 0), features.Video: None},
        ),
        (
            {torch.Tensor: (1, 2, 0), features.Image: (2, 1, 0), features.Video: (1, 2, 3, 0)},
            {torch.Tensor: (2, 0, 1), features.Image: (2, 1, 0), features.Video: (3, 0, 1, 2)},
        ),
    ],
)
def test_permute_dimensions(dims, inverse_dims):
    """Check ``PermuteDimensions`` on a mixed sample.

    ``dims`` maps an input type to the permutation to apply and ``inverse_dims`` maps the same
    type to the permutation that undoes it. For plain tensors, images and videos the output must
    be a plain ``torch.Tensor`` and, where a permutation is configured, permuting it back must
    reproduce the input. Every other entry must be passed through as the very same object.
    """
    sample = dict(
        plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"),
        image=make_image(),
        bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY),
        video=make_video(),
        str="str",
        int=0,
    )

    transform = transforms.PermuteDimensions(dims)
    transformed_sample = transform(sample)

    for key, value in sample.items():
        value_type = type(value)
        transformed_value = transformed_sample[key]

        if _isinstance(value, (features.Image, features.is_simple_tensor, features.Video)):
            # A `None` entry configures no permutation for this type, so there is nothing to invert.
            if transform.dims.get(value_type) is not None:
                assert transformed_value.permute(inverse_dims[value_type]).equal(value)
            # Exact type check on purpose: the result must not be a tensor subclass.
            assert type(transformed_value) is torch.Tensor
        else:
            assert transformed_value is value
@pytest.mark.parametrize(
    "dims",
    [
        (-1, -2),
        {torch.Tensor: (-1, -2), features.Image: (1, 2), features.Video: None},
    ],
)
def test_transpose_dimensions(dims):
    """Check ``TransposeDimensions`` on a mixed sample.

    ``dims`` is either a single pair of dimensions or a mapping from input type to such a pair
    (or ``None``). For plain tensors, images and videos the output must be a plain
    ``torch.Tensor`` and, where a pair is configured, transposing the same pair again must
    reproduce the input. Every other entry must be passed through as the very same object.
    """
    sample = dict(
        plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"),
        image=make_image(),
        bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY),
        video=make_video(),
        str="str",
        int=0,
    )

    transform = transforms.TransposeDimensions(dims)
    transformed_sample = transform(sample)

    for key, value in sample.items():
        value_type = type(value)
        transformed_value = transformed_sample[key]
        # `.get()` never invokes a defaultdict's factory, so in the non-dict parametrization this
        # only finds types that were already looked up with `dims[...]` (as `_transform` does).
        transposed_dims = transform.dims.get(value_type)

        if _isinstance(value, (features.Image, features.is_simple_tensor, features.Video)):
            if transposed_dims is not None:
                # Swapping the same two dimensions a second time undoes the transpose.
                assert transformed_value.transpose(*transposed_dims).equal(value)
            # Exact type check on purpose: the result must not be a tensor subclass.
            assert type(transformed_value) is torch.Tensor
        else:
            assert transformed_value is value
...@@ -40,7 +40,17 @@ from ._geometry import ( ...@@ -40,7 +40,17 @@ from ._geometry import (
TenCrop, TenCrop,
) )
from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype
from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, RemoveSmallBoundingBoxes, ToDtype from ._misc import (
GaussianBlur,
Identity,
Lambda,
LinearTransformation,
Normalize,
PermuteDimensions,
RemoveSmallBoundingBoxes,
ToDtype,
TransposeDimensions,
)
from ._type_conversion import DecodeImage, LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage from ._type_conversion import DecodeImage, LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage
from ._deprecated import Grayscale, RandomGrayscale, ToTensor # usort: skip from ._deprecated import Grayscale, RandomGrayscale, ToTensor # usort: skip
import functools from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
from collections import defaultdict
from typing import Any, Callable, Dict, List, Sequence, Type, Union
import PIL.Image import PIL.Image
...@@ -9,7 +7,7 @@ from torchvision.ops import remove_small_boxes ...@@ -9,7 +7,7 @@ from torchvision.ops import remove_small_boxes
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform from torchvision.prototype.transforms import functional as F, Transform
from ._utils import _setup_float_or_seq, _setup_size, has_any, query_bounding_box from ._utils import _get_defaultdict, _setup_float_or_seq, _setup_size, has_any, query_bounding_box
class Identity(Transform): class Identity(Transform):
...@@ -145,15 +143,10 @@ class GaussianBlur(Transform): ...@@ -145,15 +143,10 @@ class GaussianBlur(Transform):
class ToDtype(Transform): class ToDtype(Transform):
_transformed_types = (torch.Tensor,) _transformed_types = (torch.Tensor,)
def _default_dtype(self, dtype: torch.dtype) -> torch.dtype: def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) -> None:
return dtype
def __init__(self, dtype: Union[torch.dtype, Dict[Type, torch.dtype]]) -> None:
super().__init__() super().__init__()
if not isinstance(dtype, dict): if not isinstance(dtype, dict):
# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle. dtype = _get_defaultdict(dtype)
# If it were possible, we could replace this with `defaultdict(lambda: dtype)`
dtype = defaultdict(functools.partial(self._default_dtype, dtype))
self.dtype = dtype self.dtype = dtype
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
...@@ -163,6 +156,42 @@ class ToDtype(Transform): ...@@ -163,6 +156,42 @@ class ToDtype(Transform):
return inpt.to(dtype=dtype) return inpt.to(dtype=dtype)
class PermuteDimensions(Transform):
    """Permute the dimensions of plain tensors, images and videos.

    Args:
        dims: Either a single permutation that is used for every transformed input, or a dict
            that maps an input type to its permutation. A ``None`` value in the dict means the
            dimensions of that type are not permuted.
    """

    _transformed_types = (features.is_simple_tensor, features.Image, features.Video)

    def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None:
        super().__init__()
        # A bare sequence is wrapped so that a lookup by any input type yields that sequence.
        self.dims = dims if isinstance(dims, dict) else _get_defaultdict(dims)

    def _transform(
        self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
    ) -> torch.Tensor:
        """Permute ``inpt`` with the dims registered for its exact type.

        ``params`` is not used. When the registered value is ``None`` the data is left as is and
        only returned as a plain ``torch.Tensor``.
        """
        permutation = self.dims[type(inpt)]
        if permutation is not None:
            return inpt.permute(*permutation)
        return inpt.as_subclass(torch.Tensor)
class TransposeDimensions(Transform):
    """Swap two dimensions of plain tensors, images and videos.

    Args:
        dims: Either a single pair of dimensions that is used for every transformed input, or a
            dict that maps an input type to its pair. A ``None`` value in the dict means the
            dimensions of that type are not swapped.
    """

    _transformed_types = (features.is_simple_tensor, features.Image, features.Video)

    def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None:
        super().__init__()
        # A bare pair is wrapped so that a lookup by any input type yields that pair.
        self.dims = dims if isinstance(dims, dict) else _get_defaultdict(dims)

    def _transform(
        self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
    ) -> torch.Tensor:
        """Transpose ``inpt`` with the pair registered for its exact type.

        ``params`` is not used. When the registered value is ``None`` the data is left as is and
        only returned as a plain ``torch.Tensor``.
        """
        dim_pair = self.dims[type(inpt)]
        if dim_pair is not None:
            return inpt.transpose(*dim_pair)
        return inpt.as_subclass(torch.Tensor)
class RemoveSmallBoundingBoxes(Transform): class RemoveSmallBoundingBoxes(Transform):
_transformed_types = (features.BoundingBox, features.Mask, features.Label, features.OneHotLabel) _transformed_types = (features.BoundingBox, features.Mask, features.Label, features.OneHotLabel)
......
import functools import functools
import numbers import numbers
from collections import defaultdict from collections import defaultdict
from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, Union from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, TypeVar, Union
import PIL.Image import PIL.Image
...@@ -42,8 +42,17 @@ def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None: ...@@ -42,8 +42,17 @@ def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
raise TypeError("Got inappropriate fill arg") raise TypeError("Got inappropriate fill arg")
def _default_fill(fill: FillType) -> FillType: T = TypeVar("T")
return fill
def _default_arg(value: T) -> T:
    """Return ``value`` unchanged; used as the factory target in ``_get_defaultdict``."""
    return value


def _get_defaultdict(default: T) -> Dict[Any, T]:
    """Create a ``defaultdict`` whose factory produces ``default`` for every missing key."""
    # `defaultdict(lambda: default)` would be the natural spelling, but a `lambda` cannot be
    # serialized by pickle. A partial over a module-level function is used in its place.
    default_factory = functools.partial(_default_arg, default)
    return defaultdict(default_factory)
def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]: def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
...@@ -52,9 +61,7 @@ def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, F ...@@ -52,9 +61,7 @@ def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, F
if isinstance(fill, dict): if isinstance(fill, dict):
return fill return fill
# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle. return _get_defaultdict(fill)
# If it were possible, we could replace this with `defaultdict(lambda: fill)`
return defaultdict(functools.partial(_default_fill, fill))
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment