Unverified Commit 1120aa9e authored by Philip Meier, committed by GitHub

introduce heuristic for simple tensor handling of transforms v2 (#7170)

parent 1222b495
import itertools
import re

import numpy as np
import PIL.Image
import pytest
import torch
import torchvision.prototype.transforms.utils
from common_utils import cpu_and_gpu
from prototype_common_utils import (
    assert_equal,
    DEFAULT_EXTRA_DIMS,
    make_bounding_box,
    make_bounding_boxes,
@@ -25,7 +26,7 @@ from prototype_common_utils import (
)
from torchvision.ops.boxes import box_iou
from torchvision.prototype import datapoints, transforms
from torchvision.prototype.transforms.utils import check_type, is_simple_tensor
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -222,6 +223,67 @@ class TestSmoke:
        transform(input)

@pytest.mark.parametrize(
    "flat_inputs",
    itertools.permutations(
        [
            next(make_vanilla_tensor_images()),
            next(make_vanilla_tensor_images()),
            next(make_pil_images()),
            make_image(),
            next(make_videos()),
        ],
        3,
    ),
)
def test_simple_tensor_heuristic(flat_inputs):
    def split_on_simple_tensor(to_split):
        # This takes a sequence that is structurally aligned with `flat_inputs` and splits its items into three
        # parts:
        # 1. The first simple tensor. If none is present, this will be `None`.
        # 2. A list of the remaining simple tensors.
        # 3. A list of all other items.
        simple_tensors = []
        others = []
        # Splitting always happens on the original `flat_inputs`, so that erroneous type changes made by the
        # transform cannot affect the split.
        for item, inpt in zip(to_split, flat_inputs):
            (simple_tensors if is_simple_tensor(inpt) else others).append(item)
        return simple_tensors[0] if simple_tensors else None, simple_tensors[1:], others

    class CopyCloneTransform(transforms.Transform):
        def _transform(self, inpt, params):
            return inpt.clone() if isinstance(inpt, torch.Tensor) else inpt.copy()

        @staticmethod
        def was_applied(output, inpt):
            identity = output is inpt
            if identity:
                return False

            # Make sure nothing fishy is going on
            assert_equal(output, inpt)
            return True

    first_simple_tensor_input, other_simple_tensor_inputs, other_inputs = split_on_simple_tensor(flat_inputs)

    transform = CopyCloneTransform()
    transformed_sample = transform(flat_inputs)

    first_simple_tensor_output, other_simple_tensor_outputs, other_outputs = split_on_simple_tensor(transformed_sample)

    if first_simple_tensor_input is not None:
        if other_inputs:
            assert not transform.was_applied(first_simple_tensor_output, first_simple_tensor_input)
        else:
            assert transform.was_applied(first_simple_tensor_output, first_simple_tensor_input)

    for output, inpt in zip(other_simple_tensor_outputs, other_simple_tensor_inputs):
        assert not transform.was_applied(output, inpt)

    for input, output in zip(other_inputs, other_outputs):
        assert transform.was_applied(output, input)
@pytest.mark.parametrize("p", [0.0, 1.0]) @pytest.mark.parametrize("p", [0.0, 1.0])
class TestRandomHorizontalFlip: class TestRandomHorizontalFlip:
def input_expected_image_tensor(self, p, dtype=torch.float32): def input_expected_image_tensor(self, p, dtype=torch.float32):
...@@ -1755,117 +1817,158 @@ class TestRandomResize: ...@@ -1755,117 +1817,158 @@ class TestRandomResize:
) )

class TestToDtype:
    @pytest.mark.parametrize(
        ("dtype", "expected_dtypes"),
        [
            (
                torch.float64,
                {
                    datapoints.Video: torch.float64,
                    datapoints.Image: torch.float64,
                    datapoints.BoundingBox: torch.float64,
                },
            ),
            (
                {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
                {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
            ),
        ],
    )
    def test_call(self, dtype, expected_dtypes):
        sample = dict(
            video=make_video(dtype=torch.int64),
            image=make_image(dtype=torch.uint8),
            bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32),
            str="str",
            int=0,
        )

        transform = transforms.ToDtype(dtype)
        transformed_sample = transform(sample)

        for key, value in sample.items():
            value_type = type(value)
            transformed_value = transformed_sample[key]

            # make sure the transformation retains the type
            assert isinstance(transformed_value, value_type)

            if isinstance(value, torch.Tensor):
                assert transformed_value.dtype is expected_dtypes[value_type]
            else:
                assert transformed_value is value

    @pytest.mark.filterwarnings("error")
    def test_plain_tensor_call(self):
        tensor = torch.empty((), dtype=torch.float32)
        transform = transforms.ToDtype({torch.Tensor: torch.float64})

        assert transform(tensor).dtype is torch.float64

    @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
    def test_plain_tensor_warning(self, other_type):
        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
            transforms.ToDtype(dtype={torch.Tensor: torch.float32, other_type: torch.float64})


class TestPermuteDimensions:
    @pytest.mark.parametrize(
        ("dims", "inverse_dims"),
        [
            (
                {datapoints.Image: (2, 1, 0), datapoints.Video: None},
                {datapoints.Image: (2, 1, 0), datapoints.Video: None},
            ),
            (
                {datapoints.Image: (2, 1, 0), datapoints.Video: (1, 2, 3, 0)},
                {datapoints.Image: (2, 1, 0), datapoints.Video: (3, 0, 1, 2)},
            ),
        ],
    )
    def test_call(self, dims, inverse_dims):
        sample = dict(
            image=make_image(),
            bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
            video=make_video(),
            str="str",
            int=0,
        )

        transform = transforms.PermuteDimensions(dims)
        transformed_sample = transform(sample)

        for key, value in sample.items():
            value_type = type(value)
            transformed_value = transformed_sample[key]

            if check_type(
                value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video)
            ):
                if transform.dims.get(value_type) is not None:
                    assert transformed_value.permute(inverse_dims[value_type]).equal(value)
                assert type(transformed_value) == torch.Tensor
            else:
                assert transformed_value is value

    @pytest.mark.filterwarnings("error")
    def test_plain_tensor_call(self):
        tensor = torch.empty((2, 3, 4))
        transform = transforms.PermuteDimensions(dims=(1, 2, 0))

        assert transform(tensor).shape == (3, 4, 2)

    @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
    def test_plain_tensor_warning(self, other_type):
        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
            transforms.PermuteDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})


class TestTransposeDimensions:
    @pytest.mark.parametrize(
        "dims",
        [
            (-1, -2),
            {datapoints.Image: (1, 2), datapoints.Video: None},
        ],
    )
    def test_call(self, dims):
        sample = dict(
            image=make_image(),
            bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
            video=make_video(),
            str="str",
            int=0,
        )

        transform = transforms.TransposeDimensions(dims)
        transformed_sample = transform(sample)

        for key, value in sample.items():
            value_type = type(value)
            transformed_value = transformed_sample[key]
            transposed_dims = transform.dims.get(value_type)

            if check_type(
                value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video)
            ):
                if transposed_dims is not None:
                    assert transformed_value.transpose(*transposed_dims).equal(value)
                assert type(transformed_value) == torch.Tensor
            else:
                assert transformed_value is value

    @pytest.mark.filterwarnings("error")
    def test_plain_tensor_call(self):
        tensor = torch.empty((2, 3, 4))
        transform = transforms.TransposeDimensions(dims=(0, 2))

        assert transform(tensor).shape == (4, 3, 2)

    @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
    def test_plain_tensor_warning(self, other_type):
        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
            transforms.TransposeDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})

class TestUniformTemporalSubsample:
...
import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union

import PIL.Image
@@ -155,6 +156,12 @@ class ToDtype(Transform):
        super().__init__()
        if not isinstance(dtype, dict):
            dtype = _get_defaultdict(dtype)
        if torch.Tensor in dtype and any(cls in dtype for cls in [datapoints.Image, datapoints.Video]):
            warnings.warn(
                "Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
                "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
                "in case a `datapoints.Image` or `datapoints.Video` is present in the input."
            )
        self.dtype = dtype

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
@@ -171,6 +178,12 @@ class PermuteDimensions(Transform):
        super().__init__()
        if not isinstance(dims, dict):
            dims = _get_defaultdict(dims)
        if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]):
            warnings.warn(
                "Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
                "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
                "in case a `datapoints.Image` or `datapoints.Video` is present in the input."
            )
        self.dims = dims

    def _transform(
@@ -189,6 +202,12 @@ class TransposeDimensions(Transform):
        super().__init__()
        if not isinstance(dims, dict):
            dims = _get_defaultdict(dims)
        if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]):
            warnings.warn(
                "Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
                "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
                "in case a `datapoints.Image` or `datapoints.Video` is present in the input."
            )
        self.dims = dims

    def _transform(
...
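
For illustration, here is a minimal usage sketch (not part of the commit) of the warning added above, written against the prototype `transforms` and `datapoints` namespaces and mirroring the `test_plain_tensor_warning` and `test_plain_tensor_call` tests:

import warnings

import torch
from torchvision.prototype import datapoints, transforms

# Supplying dtypes for both `torch.Tensor` and `datapoints.Image` triggers the new UserWarning,
# because a plain tensor will not be transformed whenever an Image or Video is in the sample.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    transforms.ToDtype(dtype={torch.Tensor: torch.float32, datapoints.Image: torch.float64})
assert any("will *not* be transformed" in str(w.message) for w in caught)

# A `torch.Tensor`-only mapping stays silent and still applies to a bare tensor,
# since the heuristic treats a lone simple tensor as an image.
transform = transforms.ToDtype({torch.Tensor: torch.float64})
assert transform(torch.empty((), dtype=torch.float32)).dtype is torch.float64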
@@ -7,7 +7,8 @@
import PIL.Image

import torch
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype import datapoints
from torchvision.prototype.transforms.utils import check_type, has_any, is_simple_tensor
from torchvision.utils import _log_api_usage_once
@@ -37,9 +38,35 @@ class Transform(nn.Module):
        params = self._get_params(flat_inputs)

        # Below is a heuristic on how to deal with simple tensor inputs:
        # 1. Simple tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
        #    (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample.
        # 2. If there is no explicit image or video in the sample, only the first encountered simple tensor is
        #    transformed as an image, while the rest are passed through. The order is defined by the returned
        #    `flat_inputs` of `tree_flatten`, which recurses depth-first through the input.
        #
        # This heuristic stems from two requirements:
        # 1. We need to keep BC for single-input simple tensors and treat them as images.
        # 2. We don't want to treat all simple tensors as images, because some datasets like `CelebA` or `Widerface`
        #    return supplemental numerical data as tensors that cannot be transformed as images.
        #
        # The heuristic should work well for most people in practice. The only case where it doesn't is if someone
        # tries to transform multiple simple tensors at the same time, expecting them all to be treated as images.
        # However, this case wasn't supported by transforms v1 either, so there is no BC concern.
        flat_outputs = []
        transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
        for inpt in flat_inputs:
            needs_transform = True

            if not check_type(inpt, self._transformed_types):
                needs_transform = False
            elif is_simple_tensor(inpt):
                if transform_simple_tensor:
                    transform_simple_tensor = False
                else:
                    needs_transform = False

            flat_outputs.append(self._transform(inpt, params) if needs_transform else inpt)

        return tree_unflatten(flat_outputs, spec)
...
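
To make the heuristic above concrete, here is a minimal sketch (not part of the commit) that mirrors the `CopyCloneTransform` from the tests and shows which inputs actually get transformed; `CloneTransform` is a hypothetical toy transform:

import torch
from torchvision.prototype import datapoints, transforms


class CloneTransform(transforms.Transform):
    # Clones tensor inputs so transformed outputs can be told apart from passed-through ones.
    def _transform(self, inpt, params):
        return inpt.clone()


transform = CloneTransform()

# A lone simple tensor is treated as an image and is therefore transformed (BC with v1).
plain = torch.rand(3, 8, 8)
assert transform(plain) is not plain

# Next to an explicit `datapoints.Image`, the same simple tensor is passed through untouched,
# while the image is transformed.
image = datapoints.Image(torch.rand(3, 8, 8))
out_plain, out_image = transform((plain, image))
assert out_plain is plain
assert out_image is not image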