Unverified commit 41f9c1e6, authored by Nicolas Hug, committed by GitHub

Simple tensor -> pure tensor (#7846)

parent 4025fc5e
@@ -25,7 +25,7 @@ from torchvision.prototype import datasets
 from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import EncodedImage
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
-from torchvision.transforms.v2.utils import is_simple_tensor
+from torchvision.transforms.v2.utils import is_pure_tensor


 def assert_samples_equal(*args, msg=None, **kwargs):
@@ -140,18 +140,18 @@ class TestCommon:
             raise AssertionError(make_msg_and_close("The following streams were not closed after a full iteration:"))

     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_no_unaccompanied_simple_tensors(self, dataset_mock, config):
+    def test_no_unaccompanied_pure_tensors(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)
         sample = next_consume(iter(dataset))

-        simple_tensors = {key for key, value in sample.items() if is_simple_tensor(value)}
-        if simple_tensors and not any(
+        pure_tensors = {key for key, value in sample.items() if is_pure_tensor(value)}
+        if pure_tensors and not any(
             isinstance(item, (datapoints.Image, datapoints.Video, EncodedImage)) for item in sample.values()
         ):
             raise AssertionError(
                 f"The values of key(s) "
-                f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors, "
+                f"{sequence_to_str(sorted(pure_tensors), separate_last='and ')} contained pure tensors, "
                 f"but didn't find any (encoded) image or video."
             )
...
@@ -18,7 +18,7 @@ from prototype_common_utils import make_label
 from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
 from torchvision.prototype import datapoints, transforms
 from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_pil_image
-from torchvision.transforms.v2.utils import check_type, is_simple_tensor
+from torchvision.transforms.v2.utils import check_type, is_pure_tensor

 BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -296,7 +296,7 @@ class TestPermuteDimensions:
             value_type = type(value)
             transformed_value = transformed_sample[key]

-            if check_type(value, (Image, is_simple_tensor, Video)):
+            if check_type(value, (Image, is_pure_tensor, Video)):
                 if transform.dims.get(value_type) is not None:
                     assert transformed_value.permute(inverse_dims[value_type]).equal(value)
                 assert type(transformed_value) == torch.Tensor
@@ -341,7 +341,7 @@ class TestTransposeDimensions:
             transformed_value = transformed_sample[key]

             transposed_dims = transform.dims.get(value_type)
-            if check_type(value, (Image, is_simple_tensor, Video)):
+            if check_type(value, (Image, is_pure_tensor, Video)):
                 if transposed_dims is not None:
                     assert transformed_value.transpose(*transposed_dims).equal(value)
                 assert type(transformed_value) == torch.Tensor
...
@@ -29,7 +29,7 @@ from torchvision import datapoints
 from torchvision.ops.boxes import box_iou
 from torchvision.transforms.functional import to_pil_image
 from torchvision.transforms.v2 import functional as F
-from torchvision.transforms.v2.utils import check_type, is_simple_tensor, query_chw
+from torchvision.transforms.v2.utils import check_type, is_pure_tensor, query_chw


 def make_vanilla_tensor_images(*args, **kwargs):
@@ -71,7 +71,7 @@ def auto_augment_adapter(transform, input, device):
         if isinstance(value, (datapoints.BoundingBoxes, datapoints.Mask)):
             # AA transforms don't support bounding boxes or masks
             continue
-        elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor, PIL.Image.Image)):
+        elif check_type(value, (datapoints.Image, datapoints.Video, is_pure_tensor, PIL.Image.Image)):
             if image_or_video_found:
                 # AA transforms only support a single image or video
                 continue
@@ -101,7 +101,7 @@ def normalize_adapter(transform, input, device):
         if isinstance(value, PIL.Image.Image):
             # normalize doesn't support PIL images
             continue
-        elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor)):
+        elif check_type(value, (datapoints.Image, datapoints.Video, is_pure_tensor)):
             # normalize doesn't support integer images
             value = F.to_dtype(value, torch.float32, scale=True)
         adapted_input[key] = value
@@ -357,19 +357,19 @@ class TestSmoke:
         3,
     ),
 )
-def test_simple_tensor_heuristic(flat_inputs):
-    def split_on_simple_tensor(to_split):
+def test_pure_tensor_heuristic(flat_inputs):
+    def split_on_pure_tensor(to_split):
         # This takes a sequence that is structurally aligned with `flat_inputs` and splits its items into three parts:
-        # 1. The first simple tensor. If none is present, this will be `None`
-        # 2. A list of the remaining simple tensors
+        # 1. The first pure tensor. If none is present, this will be `None`
+        # 2. A list of the remaining pure tensors
         # 3. A list of all other items
-        simple_tensors = []
+        pure_tensors = []
         others = []
         # Splitting always happens on the original `flat_inputs` to avoid any erroneous type changes by the transform to
         # affect the splitting.
         for item, inpt in zip(to_split, flat_inputs):
-            (simple_tensors if is_simple_tensor(inpt) else others).append(item)
-        return simple_tensors[0] if simple_tensors else None, simple_tensors[1:], others
+            (pure_tensors if is_pure_tensor(inpt) else others).append(item)
+        return pure_tensors[0] if pure_tensors else None, pure_tensors[1:], others

     class CopyCloneTransform(transforms.Transform):
         def _transform(self, inpt, params):
@@ -385,20 +385,20 @@ def test_simple_tensor_heuristic(flat_inputs):
             assert_equal(output, inpt)
             return True

-    first_simple_tensor_input, other_simple_tensor_inputs, other_inputs = split_on_simple_tensor(flat_inputs)
+    first_pure_tensor_input, other_pure_tensor_inputs, other_inputs = split_on_pure_tensor(flat_inputs)

     transform = CopyCloneTransform()
     transformed_sample = transform(flat_inputs)

-    first_simple_tensor_output, other_simple_tensor_outputs, other_outputs = split_on_simple_tensor(transformed_sample)
+    first_pure_tensor_output, other_pure_tensor_outputs, other_outputs = split_on_pure_tensor(transformed_sample)

-    if first_simple_tensor_input is not None:
+    if first_pure_tensor_input is not None:
         if other_inputs:
-            assert not transform.was_applied(first_simple_tensor_output, first_simple_tensor_input)
+            assert not transform.was_applied(first_pure_tensor_output, first_pure_tensor_input)
         else:
-            assert transform.was_applied(first_simple_tensor_output, first_simple_tensor_input)
+            assert transform.was_applied(first_pure_tensor_output, first_pure_tensor_input)

-    for output, inpt in zip(other_simple_tensor_outputs, other_simple_tensor_inputs):
+    for output, inpt in zip(other_pure_tensor_outputs, other_pure_tensor_inputs):
         assert not transform.was_applied(output, inpt)

     for input, output in zip(other_inputs, other_outputs):
@@ -1004,7 +1004,7 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
         image = to_pil_image(image[0])
     elif image_type is torch.Tensor:
         image = image.as_subclass(torch.Tensor)
-        assert is_simple_tensor(image)
+        assert is_pure_tensor(image)

     label = 1 if label_type is int else torch.tensor([1])
@@ -1125,7 +1125,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
         image = to_pil_image(image[0])
     elif image_type is torch.Tensor:
         image = image.as_subclass(torch.Tensor)
-        assert is_simple_tensor(image)
+        assert is_pure_tensor(image)

     label = torch.randint(0, 10, size=(num_boxes,))
@@ -1146,7 +1146,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
     out = t(sample)

     if isinstance(to_tensor, transforms.ToTensor) and image_type is not datapoints.Image:
-        assert is_simple_tensor(out["image"])
+        assert is_pure_tensor(out["image"])
     else:
         assert isinstance(out["image"], datapoints.Image)
     assert isinstance(out["label"], type(sample["label"]))
...
@@ -602,7 +602,7 @@ def check_call_consistency(
                 raise AssertionError(
                     f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with "
                     f"the error above. This means there is a consistency bug either in `_get_params` or in the "
-                    f"`is_simple_tensor` path in `_transform`."
+                    f"`is_pure_tensor` path in `_transform`."
                 ) from exc

         assert_close(
...
@@ -24,7 +24,7 @@ from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.v2 import functional as F
 from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
 from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
-from torchvision.transforms.v2.utils import is_simple_tensor
+from torchvision.transforms.v2.utils import is_pure_tensor
 from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
 from transforms_v2_kernel_infos import KERNEL_INFOS
@@ -168,7 +168,7 @@ class TestKernels:
     def test_batched_vs_single(self, test_id, info, args_kwargs, device):
         (batched_input, *other_args), kwargs = args_kwargs.load(device)

-        datapoint_type = datapoints.Image if is_simple_tensor(batched_input) else type(batched_input)
+        datapoint_type = datapoints.Image if is_pure_tensor(batched_input) else type(batched_input)
         # This dictionary contains the number of rightmost dimensions that contain the actual data.
         # Everything to the left is considered a batch dimension.
         data_dims = {
@@ -333,9 +333,9 @@ class TestDispatchers:
         dispatcher = script(info.dispatcher)

         (image_datapoint, *other_args), kwargs = args_kwargs.load(device)
-        image_simple_tensor = torch.Tensor(image_datapoint)
+        image_pure_tensor = torch.Tensor(image_datapoint)

-        dispatcher(image_simple_tensor, *other_args, **kwargs)
+        dispatcher(image_pure_tensor, *other_args, **kwargs)

     # TODO: We need this until the dispatchers below also have `DispatcherInfo`'s. If they do, `test_scripted_smoke`
     # replaces this test for them.
@@ -358,11 +358,11 @@ class TestDispatchers:
         script(dispatcher)

     @image_sample_inputs
-    def test_simple_tensor_output_type(self, info, args_kwargs):
+    def test_pure_tensor_output_type(self, info, args_kwargs):
         (image_datapoint, *other_args), kwargs = args_kwargs.load()
-        image_simple_tensor = image_datapoint.as_subclass(torch.Tensor)
+        image_pure_tensor = image_datapoint.as_subclass(torch.Tensor)

-        output = info.dispatcher(image_simple_tensor, *other_args, **kwargs)
+        output = info.dispatcher(image_pure_tensor, *other_args, **kwargs)

         # We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
         assert type(output) is torch.Tensor
@@ -505,11 +505,11 @@ class TestClampBoundingBoxes:
             dict(canvas_size=(1, 1)),
         ],
     )
-    def test_simple_tensor_insufficient_metadata(self, metadata):
-        simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
+    def test_pure_tensor_insufficient_metadata(self, metadata):
+        pure_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)

         with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` has to be passed")):
-            F.clamp_bounding_boxes(simple_tensor, **metadata)
+            F.clamp_bounding_boxes(pure_tensor, **metadata)

     @pytest.mark.parametrize(
         "metadata",
@@ -538,11 +538,11 @@ class TestConvertFormatBoundingBoxes:
         with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")):
             F.convert_format_bounding_boxes(inpt, old_format)

-    def test_simple_tensor_insufficient_metadata(self):
-        simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
+    def test_pure_tensor_insufficient_metadata(self):
+        pure_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)

         with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
-            F.convert_format_bounding_boxes(simple_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)
+            F.convert_format_bounding_boxes(pure_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)

     def test_datapoint_explicit_metadata(self):
         datapoint = next(make_bounding_boxes())
...
@@ -37,15 +37,15 @@ MASK = make_detection_mask(DEFAULT_SIZE)
         ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True),
         ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
         ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
-        ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor), True),
+        ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor), True),
         (
            (torch.Tensor(IMAGE),),
-            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor),
+            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor),
             True,
         ),
         (
             (to_pil_image(IMAGE),),
-            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor),
+            (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor),
             True,
         ),
     ],
...
@@ -107,7 +107,7 @@ multi_crop_skips = [
         ("TestDispatchers", test_name),
         pytest.mark.skip(reason="Multi-crop dispatchers return a sequence of items rather than a single one."),
     )
-    for test_name in ["test_simple_tensor_output_type", "test_pil_output_type", "test_datapoint_output_type"]
+    for test_name in ["test_pure_tensor_output_type", "test_pil_output_type", "test_datapoint_output_type"]
 ]
 multi_crop_skips.append(skip_dispatch_datapoint)
...
@@ -9,7 +9,7 @@ from torchvision.prototype import datapoints as proto_datapoints
 from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
-from torchvision.transforms.v2.utils import is_simple_tensor
+from torchvision.transforms.v2.utils import is_pure_tensor


 class SimpleCopyPaste(Transform):
@@ -109,7 +109,7 @@ class SimpleCopyPaste(Transform):
         # with List[image], List[BoundingBoxes], List[Mask], List[Label]
         images, bboxes, masks, labels = [], [], [], []
         for obj in flat_sample:
-            if isinstance(obj, datapoints.Image) or is_simple_tensor(obj):
+            if isinstance(obj, datapoints.Image) or is_pure_tensor(obj):
                 images.append(obj)
             elif isinstance(obj, PIL.Image.Image):
                 images.append(F.to_image(obj))
@@ -146,7 +146,7 @@ class SimpleCopyPaste(Transform):
             elif isinstance(obj, PIL.Image.Image):
                 flat_sample[i] = F.to_pil_image(output_images[c0])
                 c0 += 1
-            elif is_simple_tensor(obj):
+            elif is_pure_tensor(obj):
                 flat_sample[i] = output_images[c0]
                 c0 += 1
             elif isinstance(obj, datapoints.BoundingBoxes):
...
@@ -7,7 +7,7 @@ from torchvision import datapoints
 from torchvision.prototype.datapoints import Label, OneHotLabel
 from torchvision.transforms.v2 import functional as F, Transform
 from torchvision.transforms.v2._utils import _FillType, _get_fill, _setup_fill_arg, _setup_size
-from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_simple_tensor, query_size
+from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_pure_tensor, query_size


 class FixedSizeCrop(Transform):
@@ -32,7 +32,7 @@ class FixedSizeCrop(Transform):
             flat_inputs,
             PIL.Image.Image,
             datapoints.Image,
-            is_simple_tensor,
+            is_pure_tensor,
             datapoints.Video,
         ):
             raise TypeError(
...
@@ -8,7 +8,7 @@ import torch
 from torchvision import datapoints
 from torchvision.transforms.v2 import Transform
-from torchvision.transforms.v2.utils import is_simple_tensor
+from torchvision.transforms.v2.utils import is_pure_tensor

 T = TypeVar("T")
@@ -25,7 +25,7 @@ def _get_defaultdict(default: T) -> Dict[Any, T]:
 class PermuteDimensions(Transform):
-    _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
+    _transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video)

     def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None:
         super().__init__()
@@ -47,7 +47,7 @@ class PermuteDimensions(Transform):
 class TransposeDimensions(Transform):
-    _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
+    _transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video)

     def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None:
         super().__init__()
...
@@ -12,7 +12,7 @@ from torchvision.transforms.v2 import functional as F

 from ._transform import _RandomApplyTransform, Transform
 from ._utils import _parse_labels_getter
-from .utils import has_any, is_simple_tensor, query_chw, query_size
+from .utils import has_any, is_pure_tensor, query_chw, query_size


 class RandomErasing(_RandomApplyTransform):
@@ -243,7 +243,7 @@ class MixUp(_BaseMixUpCutMix):
         if inpt is params["labels"]:
             return self._mixup_label(inpt, lam=lam)
-        elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
+        elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_pure_tensor(inpt):
             self._check_image_or_video(inpt, batch_size=params["batch_size"])

             output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
@@ -310,7 +310,7 @@ class CutMix(_BaseMixUpCutMix):
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if inpt is params["labels"]:
             return self._mixup_label(inpt, lam=params["lam_adjusted"])
-        elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
+        elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_pure_tensor(inpt):
             self._check_image_or_video(inpt, batch_size=params["batch_size"])

             x1, y1, x2, y2 = params["box"]
...
@@ -13,7 +13,7 @@ from torchvision.transforms.v2.functional._meta import get_size
 from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT

 from ._utils import _get_fill, _setup_fill_arg
-from .utils import check_type, is_simple_tensor
+from .utils import check_type, is_pure_tensor

 ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video]
@@ -50,7 +50,7 @@ class _AutoAugmentBase(Transform):
             (
                 datapoints.Image,
                 PIL.Image.Image,
-                is_simple_tensor,
+                is_pure_tensor,
                 datapoints.Video,
             ),
         ):
...
@@ -24,7 +24,7 @@ from ._utils import (
     _setup_float_or_seq,
     _setup_size,
 )
-from .utils import get_bounding_boxes, has_all, has_any, is_simple_tensor, query_size
+from .utils import get_bounding_boxes, has_all, has_any, is_pure_tensor, query_size


 class RandomHorizontalFlip(_RandomApplyTransform):
@@ -1149,7 +1149,7 @@ class RandomIoUCrop(Transform):
     def _check_inputs(self, flat_inputs: List[Any]) -> None:
         if not (
             has_all(flat_inputs, datapoints.BoundingBoxes)
-            and has_any(flat_inputs, PIL.Image.Image, datapoints.Image, is_simple_tensor)
+            and has_any(flat_inputs, PIL.Image.Image, datapoints.Image, is_pure_tensor)
         ):
             raise TypeError(
                 f"{type(self).__name__}() requires input sample to contain tensor or PIL images "
...
@@ -10,7 +10,7 @@ from torchvision import datapoints, transforms as _transforms
 from torchvision.transforms.v2 import functional as F, Transform

 from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size
-from .utils import get_bounding_boxes, has_any, is_simple_tensor
+from .utils import get_bounding_boxes, has_any, is_pure_tensor


 # TODO: do we want/need to expose this?
@@ -75,7 +75,7 @@ class LinearTransformation(Transform):
     _v1_transform_cls = _transforms.LinearTransformation

-    _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
+    _transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video)

     def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
         super().__init__()
@@ -264,7 +264,7 @@ class ToDtype(Transform):
         if isinstance(self.dtype, torch.dtype):
             # For consistency / BC with ConvertImageDtype, we only care about images or videos when dtype
             # is a simple torch.dtype
-            if not is_simple_tensor(inpt) and not isinstance(inpt, (datapoints.Image, datapoints.Video)):
+            if not is_pure_tensor(inpt) and not isinstance(inpt, (datapoints.Image, datapoints.Video)):
                 return inpt

             dtype: Optional[torch.dtype] = self.dtype
@@ -281,7 +281,7 @@ class ToDtype(Transform):
                 'e.g. dtype={datapoints.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.'
             )

-        supports_scaling = is_simple_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video))
+        supports_scaling = is_pure_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video))
         if dtype is None:
             if self.scale and supports_scaling:
                 warnings.warn(
...
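As the comments above describe, when `dtype` is a plain `torch.dtype`, `ToDtype` only converts images and videos and passes everything else through. A minimal sketch of that behavior (the sample values here are illustrative, not taken from this diff):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import ToDtype

    t = ToDtype(torch.float32, scale=True)
    img = datapoints.Image(torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8))
    mask = datapoints.Mask(torch.zeros(4, 4, dtype=torch.int64))

    out_img, out_mask = t(img, mask)
    assert out_img.dtype == torch.float32  # image converted and rescaled
    assert out_mask.dtype == torch.int64   # non-image input passed through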
@@ -8,7 +8,7 @@ import torch
 from torch import nn
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision import datapoints
-from torchvision.transforms.v2.utils import check_type, has_any, is_simple_tensor
+from torchvision.transforms.v2.utils import check_type, has_any, is_pure_tensor
 from torchvision.utils import _log_api_usage_once

 from .functional._utils import _get_kernel
@@ -55,32 +55,32 @@ class Transform(nn.Module):
         return tree_unflatten(flat_outputs, spec)

     def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]:
-        # Below is a heuristic on how to deal with simple tensor inputs:
-        # 1. Simple tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
+        # Below is a heuristic on how to deal with pure tensor inputs:
+        # 1. Pure tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
         #    (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample.
-        # 2. If there is no explicit image or video in the sample, only the first encountered simple tensor is
+        # 2. If there is no explicit image or video in the sample, only the first encountered pure tensor is
         #    transformed as image, while the rest is passed through. The order is defined by the returned `flat_inputs`
         #    of `tree_flatten`, which recurses depth-first through the input.
         #
         # This heuristic stems from two requirements:
-        # 1. We need to keep BC for single input simple tensors and treat them as images.
-        # 2. We don't want to treat all simple tensors as images, because some datasets like `CelebA` or `Widerface`
+        # 1. We need to keep BC for single input pure tensors and treat them as images.
+        # 2. We don't want to treat all pure tensors as images, because some datasets like `CelebA` or `Widerface`
         #    return supplemental numerical data as tensors that cannot be transformed as images.
         #
         # The heuristic should work well for most people in practice. The only case where it doesn't is if someone
-        # tries to transform multiple simple tensors at the same time, expecting them all to be treated as images.
+        # tries to transform multiple pure tensors at the same time, expecting them all to be treated as images.
         # However, this case wasn't supported by transforms v1 either, so there is no BC concern.

         needs_transform_list = []
-        transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
+        transform_pure_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
         for inpt in flat_inputs:
             needs_transform = True

             if not check_type(inpt, self._transformed_types):
                 needs_transform = False
-            elif is_simple_tensor(inpt):
-                if transform_simple_tensor:
-                    transform_simple_tensor = False
+            elif is_pure_tensor(inpt):
+                if transform_pure_tensor:
+                    transform_pure_tensor = False
                 else:
                     needs_transform = False
             needs_transform_list.append(needs_transform)
...
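A sketch of the heuristic described in the comments above, assuming the v2 API from this diff (the transform choice and tensor shapes are illustrative):

    import torch
    from torchvision import datapoints
    from torchvision.transforms import v2

    transform = v2.RandomHorizontalFlip(p=1.0)

    # Case 1: an explicit image is present, so the pure tensor is passed through.
    image = datapoints.Image(torch.rand(3, 8, 8))
    extra = torch.rand(4)  # e.g. supplemental numerical data, as in CelebA
    out_image, out_extra = transform(image, extra)
    assert torch.equal(out_extra, extra)

    # Case 2: only pure tensors, so the first one is transformed as an image.
    first, second = torch.rand(3, 8, 8), torch.rand(3, 8, 8)
    out_first, out_second = transform(first, second)
    assert not torch.equal(out_first, first)  # flipped
    assert torch.equal(out_second, second)    # passed through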
@@ -7,7 +7,7 @@ import torch

 from torchvision import datapoints
 from torchvision.transforms.v2 import functional as F, Transform
-from torchvision.transforms.v2.utils import is_simple_tensor
+from torchvision.transforms.v2.utils import is_pure_tensor


 class PILToTensor(Transform):
@@ -35,7 +35,7 @@ class ToImage(Transform):
     This transform does not support torchscript.
     """

-    _transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray)
+    _transformed_types = (is_pure_tensor, PIL.Image.Image, np.ndarray)

     def _transform(
         self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
@@ -65,7 +65,7 @@ class ToPILImage(Transform):
     .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
     """

-    _transformed_types = (is_simple_tensor, datapoints.Image, np.ndarray)
+    _transformed_types = (is_pure_tensor, datapoints.Image, np.ndarray)

     def __init__(self, mode: Optional[str] = None) -> None:
         super().__init__()
...
 from torchvision.transforms import InterpolationMode  # usort: skip
-from ._utils import is_simple_tensor, register_kernel  # usort: skip
+from ._utils import is_pure_tensor, register_kernel  # usort: skip

 from ._meta import (
     clamp_bounding_boxes,
...
@@ -8,7 +8,7 @@ from torchvision.transforms import _functional_pil as _FP
 from torchvision.utils import _log_api_usage_once

-from ._utils import _get_kernel, _register_kernel_internal, is_simple_tensor
+from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor


 def get_dimensions(inpt: torch.Tensor) -> List[int]:
@@ -203,7 +203,7 @@ def convert_format_bounding_boxes(
     new_format: Optional[BoundingBoxFormat] = None,
     inplace: bool = False,
 ) -> torch.Tensor:
-    # This being a kernel / functional hybrid, we need an option to pass `old_format` explicitly for simple tensor
+    # This being a kernel / functional hybrid, we need an option to pass `old_format` explicitly for pure tensor
     # inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on
     # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
     # default error that would be thrown if `new_format` had no default value.
@@ -213,9 +213,9 @@ def convert_format_bounding_boxes(
     if not torch.jit.is_scripting():
         _log_api_usage_once(convert_format_bounding_boxes)

-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+    if torch.jit.is_scripting() or is_pure_tensor(inpt):
         if old_format is None:
-            raise ValueError("For simple tensor inputs, `old_format` has to be passed.")
+            raise ValueError("For pure tensor inputs, `old_format` has to be passed.")
         return _convert_format_bounding_boxes(inpt, old_format=old_format, new_format=new_format, inplace=inplace)
     elif isinstance(inpt, datapoints.BoundingBoxes):
         if old_format is not None:
@@ -256,10 +256,10 @@ def clamp_bounding_boxes(
     if not torch.jit.is_scripting():
         _log_api_usage_once(clamp_bounding_boxes)

-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+    if torch.jit.is_scripting() or is_pure_tensor(inpt):
         if format is None or canvas_size is None:
-            raise ValueError("For simple tensor inputs, `format` and `canvas_size` has to be passed.")
+            raise ValueError("For pure tensor inputs, `format` and `canvas_size` has to be passed.")
         return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size)
     elif isinstance(inpt, datapoints.BoundingBoxes):
         if format is not None or canvas_size is not None:
...
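In other words, a `BoundingBoxes` datapoint carries its own metadata, while a pure tensor input must have `format` and `canvas_size` (or `old_format`) passed explicitly. A minimal sketch with illustrative values:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    boxes = datapoints.BoundingBoxes(
        torch.tensor([[0, 0, 10, 10]]),
        format=datapoints.BoundingBoxFormat.XYXY,
        canvas_size=(32, 32),
    )

    # Datapoint input: metadata is taken from the object itself.
    F.clamp_bounding_boxes(boxes)

    # Pure tensor input: metadata must be passed explicitly.
    F.clamp_bounding_boxes(
        boxes.as_subclass(torch.Tensor),
        format=datapoints.BoundingBoxFormat.XYXY,
        canvas_size=(32, 32),
    )
    F.convert_format_bounding_boxes(
        boxes.as_subclass(torch.Tensor),
        old_format=datapoints.BoundingBoxFormat.XYXY,
        new_format=datapoints.BoundingBoxFormat.CXCYWH,
    )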
@@ -8,7 +8,7 @@ _FillType = Union[int, float, Sequence[int], Sequence[float], None]
 _FillTypeJIT = Optional[List[float]]


-def is_simple_tensor(inpt: Any) -> bool:
+def is_pure_tensor(inpt: Any) -> bool:
     return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint)
...
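The rename does not change the semantics: a "pure" tensor is a plain `torch.Tensor` that is not a datapoint subclass. For example:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2.utils import is_pure_tensor

    assert is_pure_tensor(torch.rand(3, 32, 32))
    assert not is_pure_tensor(datapoints.Image(torch.rand(3, 32, 32)))
    # Stripping the subclass turns a datapoint back into a pure tensor:
    assert is_pure_tensor(datapoints.Image(torch.rand(3, 32, 32)).as_subclass(torch.Tensor))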
@@ -6,7 +6,7 @@ import PIL.Image

 from torchvision import datapoints
 from torchvision._utils import sequence_to_str
-from torchvision.transforms.v2.functional import get_dimensions, get_size, is_simple_tensor
+from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor


 def get_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
@@ -21,7 +21,7 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
     chws = {
         tuple(get_dimensions(inpt))
         for inpt in flat_inputs
-        if check_type(inpt, (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video))
+        if check_type(inpt, (is_pure_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video))
     }
     if not chws:
         raise TypeError("No image or video was found in the sample")
@@ -38,7 +38,7 @@ def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
         if check_type(
             inpt,
             (
-                is_simple_tensor,
+                is_pure_tensor,
                 datapoints.Image,
                 PIL.Image.Image,
                 datapoints.Video,
...
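A short usage sketch of these query helpers on an illustrative flattened sample (entries that don't match `check_type` are simply skipped):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2.utils import query_chw, query_size

    flat_inputs = [datapoints.Image(torch.rand(3, 16, 24)), "some label"]
    assert query_chw(flat_inputs) == (3, 16, 24)
    assert query_size(flat_inputs) == (16, 24)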