Unverified Commit f88ab124 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add PermuteDimensions and TransposeDimensions transforms (#6800)

* Add PermuteDimensions and TransposeDimensions transforms

* Strip Subclass info.

* Apply changes from code review.
parent 37618552
......@@ -18,10 +18,12 @@ from prototype_common_utils import (
make_masks,
make_one_hot_labels,
make_segmentation_mask,
make_video,
make_videos,
)
from torchvision.ops.boxes import box_iou
from torchvision.prototype import features, transforms
from torchvision.prototype.transforms._utils import _isinstance
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
......@@ -1826,3 +1828,74 @@ def test_to_dtype(dtype, expected_dtypes):
assert transformed_value.dtype is expected_dtypes[value_type]
else:
assert transformed_value is value
@pytest.mark.parametrize(
("dims", "inverse_dims"),
[
(
{torch.Tensor: (1, 2, 0), features.Image: (2, 1, 0), features.Video: None},
{torch.Tensor: (2, 0, 1), features.Image: (2, 1, 0), features.Video: None},
),
(
{torch.Tensor: (1, 2, 0), features.Image: (2, 1, 0), features.Video: (1, 2, 3, 0)},
{torch.Tensor: (2, 0, 1), features.Image: (2, 1, 0), features.Video: (3, 0, 1, 2)},
),
],
)
def test_permute_dimensions(dims, inverse_dims):
sample = dict(
plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"),
image=make_image(),
bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY),
video=make_video(),
str="str",
int=0,
)
transform = transforms.PermuteDimensions(dims)
transformed_sample = transform(sample)
for key, value in sample.items():
value_type = type(value)
transformed_value = transformed_sample[key]
if _isinstance(value, (features.Image, features.is_simple_tensor, features.Video)):
if transform.dims.get(value_type) is not None:
assert transformed_value.permute(inverse_dims[value_type]).equal(value)
assert type(transformed_value) == torch.Tensor
else:
assert transformed_value is value
@pytest.mark.parametrize(
"dims",
[
(-1, -2),
{torch.Tensor: (-1, -2), features.Image: (1, 2), features.Video: None},
],
)
def test_transpose_dimensions(dims):
sample = dict(
plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"),
image=make_image(),
bounding_box=make_bounding_box(format=features.BoundingBoxFormat.XYXY),
video=make_video(),
str="str",
int=0,
)
transform = transforms.TransposeDimensions(dims)
transformed_sample = transform(sample)
for key, value in sample.items():
value_type = type(value)
transformed_value = transformed_sample[key]
transposed_dims = transform.dims.get(value_type)
if _isinstance(value, (features.Image, features.is_simple_tensor, features.Video)):
if transposed_dims is not None:
assert transformed_value.transpose(*transposed_dims).equal(value)
assert type(transformed_value) == torch.Tensor
else:
assert transformed_value is value
......@@ -40,7 +40,17 @@ from ._geometry import (
TenCrop,
)
from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype
from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, RemoveSmallBoundingBoxes, ToDtype
from ._misc import (
GaussianBlur,
Identity,
Lambda,
LinearTransformation,
Normalize,
PermuteDimensions,
RemoveSmallBoundingBoxes,
ToDtype,
TransposeDimensions,
)
from ._type_conversion import DecodeImage, LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage
from ._deprecated import Grayscale, RandomGrayscale, ToTensor # usort: skip
import functools
from collections import defaultdict
from typing import Any, Callable, Dict, List, Sequence, Type, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
import PIL.Image
......@@ -9,7 +7,7 @@ from torchvision.ops import remove_small_boxes
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from ._utils import _setup_float_or_seq, _setup_size, has_any, query_bounding_box
from ._utils import _get_defaultdict, _setup_float_or_seq, _setup_size, has_any, query_bounding_box
class Identity(Transform):
......@@ -145,15 +143,10 @@ class GaussianBlur(Transform):
class ToDtype(Transform):
_transformed_types = (torch.Tensor,)
def _default_dtype(self, dtype: torch.dtype) -> torch.dtype:
return dtype
def __init__(self, dtype: Union[torch.dtype, Dict[Type, torch.dtype]]) -> None:
def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) -> None:
super().__init__()
if not isinstance(dtype, dict):
# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
# If it were possible, we could replace this with `defaultdict(lambda: dtype)`
dtype = defaultdict(functools.partial(self._default_dtype, dtype))
dtype = _get_defaultdict(dtype)
self.dtype = dtype
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
......@@ -163,6 +156,42 @@ class ToDtype(Transform):
return inpt.to(dtype=dtype)
class PermuteDimensions(Transform):
_transformed_types = (features.is_simple_tensor, features.Image, features.Video)
def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None:
super().__init__()
if not isinstance(dims, dict):
dims = _get_defaultdict(dims)
self.dims = dims
def _transform(
self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
) -> torch.Tensor:
dims = self.dims[type(inpt)]
if dims is None:
return inpt.as_subclass(torch.Tensor)
return inpt.permute(*dims)
class TransposeDimensions(Transform):
_transformed_types = (features.is_simple_tensor, features.Image, features.Video)
def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None:
super().__init__()
if not isinstance(dims, dict):
dims = _get_defaultdict(dims)
self.dims = dims
def _transform(
self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
) -> torch.Tensor:
dims = self.dims[type(inpt)]
if dims is None:
return inpt.as_subclass(torch.Tensor)
return inpt.transpose(*dims)
class RemoveSmallBoundingBoxes(Transform):
_transformed_types = (features.BoundingBox, features.Mask, features.Label, features.OneHotLabel)
......
import functools
import numbers
from collections import defaultdict
from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, TypeVar, Union
import PIL.Image
......@@ -42,8 +42,17 @@ def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
raise TypeError("Got inappropriate fill arg")
def _default_fill(fill: FillType) -> FillType:
return fill
T = TypeVar("T")
def _default_arg(value: T) -> T:
return value
def _get_defaultdict(default: T) -> Dict[Any, T]:
# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
# If it were possible, we could replace this with `defaultdict(lambda: default)`
return defaultdict(functools.partial(_default_arg, default))
def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
......@@ -52,9 +61,7 @@ def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, F
if isinstance(fill, dict):
return fill
# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
# If it were possible, we could replace this with `defaultdict(lambda: fill)`
return defaultdict(functools.partial(_default_fill, fill))
return _get_defaultdict(fill)
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment