Unverified Commit 41f9c1e6 authored by Nicolas Hug, committed by GitHub

Simple tensor -> pure tensor (#7846)

parent 4025fc5e
@@ -25,7 +25,7 @@ from torchvision.prototype import datasets
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import EncodedImage
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
-from torchvision.transforms.v2.utils import is_simple_tensor
+from torchvision.transforms.v2.utils import is_pure_tensor
def assert_samples_equal(*args, msg=None, **kwargs):
@@ -140,18 +140,18 @@ class TestCommon:
raise AssertionError(make_msg_and_close("The following streams were not closed after a full iteration:"))
@parametrize_dataset_mocks(DATASET_MOCKS)
-def test_no_unaccompanied_simple_tensors(self, dataset_mock, config):
+def test_no_unaccompanied_pure_tensors(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config)
sample = next_consume(iter(dataset))
-simple_tensors = {key for key, value in sample.items() if is_simple_tensor(value)}
+pure_tensors = {key for key, value in sample.items() if is_pure_tensor(value)}
-if simple_tensors and not any(
+if pure_tensors and not any(
isinstance(item, (datapoints.Image, datapoints.Video, EncodedImage)) for item in sample.values()
):
raise AssertionError(
f"The values of key(s) "
-f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors, "
+f"{sequence_to_str(sorted(pure_tensors), separate_last='and ')} contained pure tensors, "
f"but didn't find any (encoded) image or video."
)
......
@@ -18,7 +18,7 @@ from prototype_common_utils import make_label
from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
from torchvision.prototype import datapoints, transforms
from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_pil_image
-from torchvision.transforms.v2.utils import check_type, is_simple_tensor
+from torchvision.transforms.v2.utils import check_type, is_pure_tensor
BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -296,7 +296,7 @@ class TestPermuteDimensions:
value_type = type(value)
transformed_value = transformed_sample[key]
-if check_type(value, (Image, is_simple_tensor, Video)):
+if check_type(value, (Image, is_pure_tensor, Video)):
if transform.dims.get(value_type) is not None:
assert transformed_value.permute(inverse_dims[value_type]).equal(value)
assert type(transformed_value) == torch.Tensor
@@ -341,7 +341,7 @@ class TestTransposeDimensions:
transformed_value = transformed_sample[key]
transposed_dims = transform.dims.get(value_type)
-if check_type(value, (Image, is_simple_tensor, Video)):
+if check_type(value, (Image, is_pure_tensor, Video)):
if transposed_dims is not None:
assert transformed_value.transpose(*transposed_dims).equal(value)
assert type(transformed_value) == torch.Tensor
......
@@ -29,7 +29,7 @@ from torchvision import datapoints
from torchvision.ops.boxes import box_iou
from torchvision.transforms.functional import to_pil_image
from torchvision.transforms.v2 import functional as F
-from torchvision.transforms.v2.utils import check_type, is_simple_tensor, query_chw
+from torchvision.transforms.v2.utils import check_type, is_pure_tensor, query_chw
def make_vanilla_tensor_images(*args, **kwargs):
@@ -71,7 +71,7 @@ def auto_augment_adapter(transform, input, device):
if isinstance(value, (datapoints.BoundingBoxes, datapoints.Mask)):
# AA transforms don't support bounding boxes or masks
continue
-elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor, PIL.Image.Image)):
+elif check_type(value, (datapoints.Image, datapoints.Video, is_pure_tensor, PIL.Image.Image)):
if image_or_video_found:
# AA transforms only support a single image or video
continue
@@ -101,7 +101,7 @@ def normalize_adapter(transform, input, device):
if isinstance(value, PIL.Image.Image):
# normalize doesn't support PIL images
continue
-elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor)):
+elif check_type(value, (datapoints.Image, datapoints.Video, is_pure_tensor)):
# normalize doesn't support integer images
value = F.to_dtype(value, torch.float32, scale=True)
adapted_input[key] = value
@@ -357,19 +357,19 @@ class TestSmoke:
3,
),
)
-def test_simple_tensor_heuristic(flat_inputs):
-def split_on_simple_tensor(to_split):
+def test_pure_tensor_heuristic(flat_inputs):
+def split_on_pure_tensor(to_split):
# This takes a sequence that is structurally aligned with `flat_inputs` and splits its items into three parts:
-# 1. The first simple tensor. If none is present, this will be `None`
-# 2. A list of the remaining simple tensors
+# 1. The first pure tensor. If none is present, this will be `None`
+# 2. A list of the remaining pure tensors
# 3. A list of all other items
-simple_tensors = []
+pure_tensors = []
others = []
# Splitting always happens on the original `flat_inputs` to avoid any erroneous type changes by the transform to
# affect the splitting.
for item, inpt in zip(to_split, flat_inputs):
-(simple_tensors if is_simple_tensor(inpt) else others).append(item)
-return simple_tensors[0] if simple_tensors else None, simple_tensors[1:], others
+(pure_tensors if is_pure_tensor(inpt) else others).append(item)
+return pure_tensors[0] if pure_tensors else None, pure_tensors[1:], others
class CopyCloneTransform(transforms.Transform):
def _transform(self, inpt, params):
@@ -385,20 +385,20 @@ def test_simple_tensor_heuristic(flat_inputs):
assert_equal(output, inpt)
return True
-first_simple_tensor_input, other_simple_tensor_inputs, other_inputs = split_on_simple_tensor(flat_inputs)
+first_pure_tensor_input, other_pure_tensor_inputs, other_inputs = split_on_pure_tensor(flat_inputs)
transform = CopyCloneTransform()
transformed_sample = transform(flat_inputs)
-first_simple_tensor_output, other_simple_tensor_outputs, other_outputs = split_on_simple_tensor(transformed_sample)
+first_pure_tensor_output, other_pure_tensor_outputs, other_outputs = split_on_pure_tensor(transformed_sample)
-if first_simple_tensor_input is not None:
+if first_pure_tensor_input is not None:
if other_inputs:
-assert not transform.was_applied(first_simple_tensor_output, first_simple_tensor_input)
+assert not transform.was_applied(first_pure_tensor_output, first_pure_tensor_input)
else:
-assert transform.was_applied(first_simple_tensor_output, first_simple_tensor_input)
+assert transform.was_applied(first_pure_tensor_output, first_pure_tensor_input)
-for output, inpt in zip(other_simple_tensor_outputs, other_simple_tensor_inputs):
+for output, inpt in zip(other_pure_tensor_outputs, other_pure_tensor_inputs):
assert not transform.was_applied(output, inpt)
for input, output in zip(other_inputs, other_outputs):
@@ -1004,7 +1004,7 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
image = to_pil_image(image[0])
elif image_type is torch.Tensor:
image = image.as_subclass(torch.Tensor)
-assert is_simple_tensor(image)
+assert is_pure_tensor(image)
label = 1 if label_type is int else torch.tensor([1])
@@ -1125,7 +1125,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
image = to_pil_image(image[0])
elif image_type is torch.Tensor:
image = image.as_subclass(torch.Tensor)
-assert is_simple_tensor(image)
+assert is_pure_tensor(image)
label = torch.randint(0, 10, size=(num_boxes,))
@@ -1146,7 +1146,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
out = t(sample)
if isinstance(to_tensor, transforms.ToTensor) and image_type is not datapoints.Image:
-assert is_simple_tensor(out["image"])
+assert is_pure_tensor(out["image"])
else:
assert isinstance(out["image"], datapoints.Image)
assert isinstance(out["label"], type(sample["label"]))
......
@@ -602,7 +602,7 @@ def check_call_consistency(
raise AssertionError(
f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with "
f"the error above. This means there is a consistency bug either in `_get_params` or in the "
-f"`is_simple_tensor` path in `_transform`."
+f"`is_pure_tensor` path in `_transform`."
) from exc
assert_close(
......
@@ -24,7 +24,7 @@ from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
-from torchvision.transforms.v2.utils import is_simple_tensor
+from torchvision.transforms.v2.utils import is_pure_tensor
from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
from transforms_v2_kernel_infos import KERNEL_INFOS
@@ -168,7 +168,7 @@ class TestKernels:
def test_batched_vs_single(self, test_id, info, args_kwargs, device):
(batched_input, *other_args), kwargs = args_kwargs.load(device)
-datapoint_type = datapoints.Image if is_simple_tensor(batched_input) else type(batched_input)
+datapoint_type = datapoints.Image if is_pure_tensor(batched_input) else type(batched_input)
# This dictionary contains the number of rightmost dimensions that contain the actual data.
# Everything to the left is considered a batch dimension.
data_dims = {
@@ -333,9 +333,9 @@ class TestDispatchers:
dispatcher = script(info.dispatcher)
(image_datapoint, *other_args), kwargs = args_kwargs.load(device)
-image_simple_tensor = torch.Tensor(image_datapoint)
+image_pure_tensor = torch.Tensor(image_datapoint)
-dispatcher(image_simple_tensor, *other_args, **kwargs)
+dispatcher(image_pure_tensor, *other_args, **kwargs)
# TODO: We need this until the dispatchers below also have `DispatcherInfo`'s. If they do, `test_scripted_smoke`
# replaces this test for them.
@@ -358,11 +358,11 @@ class TestDispatchers:
script(dispatcher)
@image_sample_inputs
-def test_simple_tensor_output_type(self, info, args_kwargs):
+def test_pure_tensor_output_type(self, info, args_kwargs):
(image_datapoint, *other_args), kwargs = args_kwargs.load()
-image_simple_tensor = image_datapoint.as_subclass(torch.Tensor)
+image_pure_tensor = image_datapoint.as_subclass(torch.Tensor)
-output = info.dispatcher(image_simple_tensor, *other_args, **kwargs)
+output = info.dispatcher(image_pure_tensor, *other_args, **kwargs)
# We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
assert type(output) is torch.Tensor
@@ -505,11 +505,11 @@ class TestClampBoundingBoxes:
dict(canvas_size=(1, 1)),
],
)
-def test_simple_tensor_insufficient_metadata(self, metadata):
-simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
+def test_pure_tensor_insufficient_metadata(self, metadata):
+pure_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` has to be passed")):
-F.clamp_bounding_boxes(simple_tensor, **metadata)
+F.clamp_bounding_boxes(pure_tensor, **metadata)
@pytest.mark.parametrize(
"metadata",
@@ -538,11 +538,11 @@ class TestConvertFormatBoundingBoxes:
with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")):
F.convert_format_bounding_boxes(inpt, old_format)
-def test_simple_tensor_insufficient_metadata(self):
-simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
+def test_pure_tensor_insufficient_metadata(self):
+pure_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
-F.convert_format_bounding_boxes(simple_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)
+F.convert_format_bounding_boxes(pure_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)
def test_datapoint_explicit_metadata(self):
datapoint = next(make_bounding_boxes())
......
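Context for the two test classes above: a `datapoints.BoundingBoxes` carries `format` and `canvas_size` with it, while a pure tensor does not, so the functionals demand that metadata explicitly. A minimal sketch of the contract being tested (box and canvas values are illustrative):

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

boxes = datapoints.BoundingBoxes(
    torch.tensor([[0.0, 0.0, 20.0, 20.0]]),
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(16, 16),
)

# Datapoint input: the metadata travels with the object.
clamped = F.clamp_bounding_boxes(boxes)

# Pure tensor input: metadata must be passed explicitly, else ValueError.
pure = boxes.as_subclass(torch.Tensor)
clamped_pure = F.clamp_bounding_boxes(
    pure, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(16, 16)
)
```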
@@ -37,15 +37,15 @@ MASK = make_detection_mask(DEFAULT_SIZE)
((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
-((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor), True),
+((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor), True),
(
(torch.Tensor(IMAGE),),
-(datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor),
+(datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor),
True,
),
(
(to_pil_image(IMAGE),),
-(datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor),
+(datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor),
True,
),
],
......
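The parametrization above mixes concrete types with callables; `check_type` and `has_any` invoke a predicate like `is_pure_tensor` where they would otherwise call `isinstance`. A small sketch of that duality:

```python
import PIL.Image
import torch
from torchvision import datapoints
from torchvision.transforms.v2.utils import check_type, has_any, is_pure_tensor

sample = [torch.rand(3, 4, 4)]  # a pure tensor, as in the torch.Tensor(IMAGE) case above
# Types are matched via isinstance; callables are simply evaluated on the input.
assert check_type(sample[0], (datapoints.Image, PIL.Image.Image, is_pure_tensor))
assert has_any(sample, is_pure_tensor)
```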
@@ -107,7 +107,7 @@ multi_crop_skips = [
("TestDispatchers", test_name),
pytest.mark.skip(reason="Multi-crop dispatchers return a sequence of items rather than a single one."),
)
-for test_name in ["test_simple_tensor_output_type", "test_pil_output_type", "test_datapoint_output_type"]
+for test_name in ["test_pure_tensor_output_type", "test_pil_output_type", "test_datapoint_output_type"]
]
multi_crop_skips.append(skip_dispatch_datapoint)
......
@@ -9,7 +9,7 @@ from torchvision.prototype import datapoints as proto_datapoints
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
from torchvision.transforms.v2.functional._geometry import _check_interpolation
-from torchvision.transforms.v2.utils import is_simple_tensor
+from torchvision.transforms.v2.utils import is_pure_tensor
class SimpleCopyPaste(Transform):
@@ -109,7 +109,7 @@ class SimpleCopyPaste(Transform):
# with List[image], List[BoundingBoxes], List[Mask], List[Label]
images, bboxes, masks, labels = [], [], [], []
for obj in flat_sample:
-if isinstance(obj, datapoints.Image) or is_simple_tensor(obj):
+if isinstance(obj, datapoints.Image) or is_pure_tensor(obj):
images.append(obj)
elif isinstance(obj, PIL.Image.Image):
images.append(F.to_image(obj))
@@ -146,7 +146,7 @@ class SimpleCopyPaste(Transform):
elif isinstance(obj, PIL.Image.Image):
flat_sample[i] = F.to_pil_image(output_images[c0])
c0 += 1
-elif is_simple_tensor(obj):
+elif is_pure_tensor(obj):
flat_sample[i] = output_images[c0]
c0 += 1
elif isinstance(obj, datapoints.BoundingBoxes):
......
@@ -7,7 +7,7 @@ from torchvision import datapoints
from torchvision.prototype.datapoints import Label, OneHotLabel
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2._utils import _FillType, _get_fill, _setup_fill_arg, _setup_size
-from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_simple_tensor, query_size
+from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_pure_tensor, query_size
class FixedSizeCrop(Transform):
@@ -32,7 +32,7 @@ class FixedSizeCrop(Transform):
flat_inputs,
PIL.Image.Image,
datapoints.Image,
-is_simple_tensor,
+is_pure_tensor,
datapoints.Video,
):
raise TypeError(
......
@@ -8,7 +8,7 @@ import torch
from torchvision import datapoints
from torchvision.transforms.v2 import Transform
-from torchvision.transforms.v2.utils import is_simple_tensor
+from torchvision.transforms.v2.utils import is_pure_tensor
T = TypeVar("T")
@@ -25,7 +25,7 @@ def _get_defaultdict(default: T) -> Dict[Any, T]:
class PermuteDimensions(Transform):
-_transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
+_transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video)
def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None:
super().__init__()
@@ -47,7 +47,7 @@
class TransposeDimensions(Transform):
-_transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
+_transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video)
def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None:
super().__init__()
......
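As the two prototype transforms above show, `_transformed_types` may mix datapoint classes with the `is_pure_tensor` predicate; inputs matching none of the entries are passed through untouched. A hypothetical custom transform following the same pattern (the class name and operation are made up for illustration):

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2.utils import is_pure_tensor


class HalveValues(Transform):
    # Plain tensors, Images, and Videos are transformed; everything else passes through.
    _transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video)

    def _transform(self, inpt, params):
        return inpt * 0.5
```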
@@ -12,7 +12,7 @@ from torchvision.transforms.v2 import functional as F
from ._transform import _RandomApplyTransform, Transform
from ._utils import _parse_labels_getter
-from .utils import has_any, is_simple_tensor, query_chw, query_size
+from .utils import has_any, is_pure_tensor, query_chw, query_size
class RandomErasing(_RandomApplyTransform):
@@ -243,7 +243,7 @@ class MixUp(_BaseMixUpCutMix):
if inpt is params["labels"]:
return self._mixup_label(inpt, lam=lam)
-elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
+elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_pure_tensor(inpt):
self._check_image_or_video(inpt, batch_size=params["batch_size"])
output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
@@ -310,7 +310,7 @@ class CutMix(_BaseMixUpCutMix):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if inpt is params["labels"]:
return self._mixup_label(inpt, lam=params["lam_adjusted"])
-elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
+elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_pure_tensor(inpt):
self._check_image_or_video(inpt, batch_size=params["batch_size"])
x1, y1, x2, y2 = params["box"]
......
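The image branch shared by `MixUp` and `CutMix` above blends each sample with its neighbor in the batch: `roll(1, 0)` shifts the batch by one along dim 0, and the result is `lam * x + (1 - lam) * x_rolled`. A tiny numeric sketch of that line, without the in-place ops:

```python
import torch

lam = 0.3
batch = torch.arange(8.0).reshape(4, 2)  # stand-in for a batch of images
mixed = batch.roll(1, 0).mul(1.0 - lam).add(batch.mul(lam))
# Row i is 0.3 * batch[i] + 0.7 * batch[i - 1], wrapping around at the top.
```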
@@ -13,7 +13,7 @@ from torchvision.transforms.v2.functional._meta import get_size
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
from ._utils import _get_fill, _setup_fill_arg
-from .utils import check_type, is_simple_tensor
+from .utils import check_type, is_pure_tensor
ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video]
@@ -50,7 +50,7 @@ class _AutoAugmentBase(Transform):
(
datapoints.Image,
PIL.Image.Image,
-is_simple_tensor,
+is_pure_tensor,
datapoints.Video,
),
):
......
@@ -24,7 +24,7 @@ from ._utils import (
_setup_float_or_seq,
_setup_size,
)
-from .utils import get_bounding_boxes, has_all, has_any, is_simple_tensor, query_size
+from .utils import get_bounding_boxes, has_all, has_any, is_pure_tensor, query_size
class RandomHorizontalFlip(_RandomApplyTransform):
@@ -1149,7 +1149,7 @@ class RandomIoUCrop(Transform):
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if not (
has_all(flat_inputs, datapoints.BoundingBoxes)
-and has_any(flat_inputs, PIL.Image.Image, datapoints.Image, is_simple_tensor)
+and has_any(flat_inputs, PIL.Image.Image, datapoints.Image, is_pure_tensor)
):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain tensor or PIL images "
......
@@ -10,7 +10,7 @@ from torchvision import datapoints, transforms as _transforms
from torchvision.transforms.v2 import functional as F, Transform
from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size
-from .utils import get_bounding_boxes, has_any, is_simple_tensor
+from .utils import get_bounding_boxes, has_any, is_pure_tensor
# TODO: do we want/need to expose this?
@@ -75,7 +75,7 @@ class LinearTransformation(Transform):
_v1_transform_cls = _transforms.LinearTransformation
-_transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
+_transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video)
def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
super().__init__()
@@ -264,7 +264,7 @@ class ToDtype(Transform):
if isinstance(self.dtype, torch.dtype):
# For consistency / BC with ConvertImageDtype, we only care about images or videos when dtype
# is a simple torch.dtype
-if not is_simple_tensor(inpt) and not isinstance(inpt, (datapoints.Image, datapoints.Video)):
+if not is_pure_tensor(inpt) and not isinstance(inpt, (datapoints.Image, datapoints.Video)):
return inpt
dtype: Optional[torch.dtype] = self.dtype
@@ -281,7 +281,7 @@ class ToDtype(Transform):
'e.g. dtype={datapoints.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.'
)
-supports_scaling = is_simple_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video))
+supports_scaling = is_pure_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video))
if dtype is None:
if self.scale and supports_scaling:
warnings.warn(
......
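The two `ToDtype` branches above behave differently: with a plain `torch.dtype`, only images, videos, and pure tensors are converted; with a dict, each type is targeted explicitly. A hedged sketch against the v2 API as of this commit:

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import ToDtype

# Plain dtype: a lone pure tensor is treated as an image and converted + scaled.
t = ToDtype(torch.float32, scale=True)
out = t(torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8))  # float32 in [0, 1]

# Dict dtype: target specific types; "others": None passes the rest through.
t = ToDtype({datapoints.Mask: torch.int64, "others": None})
```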
@@ -8,7 +8,7 @@ import torch
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints
-from torchvision.transforms.v2.utils import check_type, has_any, is_simple_tensor
+from torchvision.transforms.v2.utils import check_type, has_any, is_pure_tensor
from torchvision.utils import _log_api_usage_once
from .functional._utils import _get_kernel
@@ -55,32 +55,32 @@ class Transform(nn.Module):
return tree_unflatten(flat_outputs, spec)
def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]:
-# Below is a heuristic on how to deal with simple tensor inputs:
-# 1. Simple tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
+# Below is a heuristic on how to deal with pure tensor inputs:
+# 1. Pure tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
# (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample.
-# 2. If there is no explicit image or video in the sample, only the first encountered simple tensor is
+# 2. If there is no explicit image or video in the sample, only the first encountered pure tensor is
# transformed as image, while the rest is passed through. The order is defined by the returned `flat_inputs`
# of `tree_flatten`, which recurses depth-first through the input.
#
# This heuristic stems from two requirements:
-# 1. We need to keep BC for single input simple tensors and treat them as images.
-# 2. We don't want to treat all simple tensors as images, because some datasets like `CelebA` or `Widerface`
+# 1. We need to keep BC for single input pure tensors and treat them as images.
+# 2. We don't want to treat all pure tensors as images, because some datasets like `CelebA` or `Widerface`
# return supplemental numerical data as tensors that cannot be transformed as images.
#
# The heuristic should work well for most people in practice. The only case where it doesn't is if someone
-# tries to transform multiple simple tensors at the same time, expecting them all to be treated as images.
+# tries to transform multiple pure tensors at the same time, expecting them all to be treated as images.
# However, this case wasn't supported by transforms v1 either, so there is no BC concern.
needs_transform_list = []
-transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
+transform_pure_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
for inpt in flat_inputs:
needs_transform = True
if not check_type(inpt, self._transformed_types):
needs_transform = False
-elif is_simple_tensor(inpt):
-if transform_simple_tensor:
-transform_simple_tensor = False
+elif is_pure_tensor(inpt):
+if transform_pure_tensor:
+transform_pure_tensor = False
else:
needs_transform = False
needs_transform_list.append(needs_transform)
......
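The heuristic in `_needs_transform_list` is easiest to see on a concrete sample. A minimal sketch (the transform choice is arbitrary):

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import RandomHorizontalFlip

flip = RandomHorizontalFlip(p=1.0)

# Case 1: a pure tensor next to an explicit image is passed through untouched.
image = datapoints.Image(torch.rand(3, 8, 8))
scores = torch.rand(10)  # supplemental data, not an image
out_image, out_scores = flip((image, scores))
assert torch.equal(out_scores, scores)

# Case 2: with no explicit image or video in the sample, the first (here, only)
# pure tensor keeps v1 behavior and is transformed as an image.
out = flip(torch.rand(3, 8, 8))
```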
@@ -7,7 +7,7 @@ import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F, Transform
-from torchvision.transforms.v2.utils import is_simple_tensor
+from torchvision.transforms.v2.utils import is_pure_tensor
class PILToTensor(Transform):
@@ -35,7 +35,7 @@ class ToImage(Transform):
This transform does not support torchscript.
"""
-_transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray)
+_transformed_types = (is_pure_tensor, PIL.Image.Image, np.ndarray)
def _transform(
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
@@ -65,7 +65,7 @@ class ToPILImage(Transform):
.. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
"""
-_transformed_types = (is_simple_tensor, datapoints.Image, np.ndarray)
+_transformed_types = (is_pure_tensor, datapoints.Image, np.ndarray)
def __init__(self, mode: Optional[str] = None) -> None:
super().__init__()
......
from torchvision.transforms import InterpolationMode # usort: skip
-from ._utils import is_simple_tensor, register_kernel # usort: skip
+from ._utils import is_pure_tensor, register_kernel # usort: skip
from ._meta import (
clamp_bounding_boxes,
......
@@ -8,7 +8,7 @@ from torchvision.transforms import _functional_pil as _FP
from torchvision.utils import _log_api_usage_once
-from ._utils import _get_kernel, _register_kernel_internal, is_simple_tensor
+from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor
def get_dimensions(inpt: torch.Tensor) -> List[int]:
@@ -203,7 +203,7 @@ def convert_format_bounding_boxes(
new_format: Optional[BoundingBoxFormat] = None,
inplace: bool = False,
) -> torch.Tensor:
-# This being a kernel / functional hybrid, we need an option to pass `old_format` explicitly for simple tensor
+# This being a kernel / functional hybrid, we need an option to pass `old_format` explicitly for pure tensor
# inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on
# `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
# default error that would be thrown if `new_format` had no default value.
@@ -213,9 +213,9 @@ def convert_format_bounding_boxes(
if not torch.jit.is_scripting():
_log_api_usage_once(convert_format_bounding_boxes)
-if torch.jit.is_scripting() or is_simple_tensor(inpt):
+if torch.jit.is_scripting() or is_pure_tensor(inpt):
if old_format is None:
-raise ValueError("For simple tensor inputs, `old_format` has to be passed.")
+raise ValueError("For pure tensor inputs, `old_format` has to be passed.")
return _convert_format_bounding_boxes(inpt, old_format=old_format, new_format=new_format, inplace=inplace)
elif isinstance(inpt, datapoints.BoundingBoxes):
if old_format is not None:
@@ -256,10 +256,10 @@ def clamp_bounding_boxes(
if not torch.jit.is_scripting():
_log_api_usage_once(clamp_bounding_boxes)
-if torch.jit.is_scripting() or is_simple_tensor(inpt):
+if torch.jit.is_scripting() or is_pure_tensor(inpt):
if format is None or canvas_size is None:
-raise ValueError("For simple tensor inputs, `format` and `canvas_size` has to be passed.")
+raise ValueError("For pure tensor inputs, `format` and `canvas_size` has to be passed.")
return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size)
elif isinstance(inpt, datapoints.BoundingBoxes):
if format is not None or canvas_size is not None:
......
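`convert_format_bounding_boxes` follows the same kernel/functional hybrid pattern as `clamp_bounding_boxes` above: a pure tensor is accepted as long as `old_format` is supplied. Illustrative values only:

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

xyxy = torch.tensor([[10.0, 10.0, 50.0, 80.0]])  # a pure tensor, no metadata attached
cxcywh = F.convert_format_bounding_boxes(
    xyxy,
    old_format=datapoints.BoundingBoxFormat.XYXY,
    new_format=datapoints.BoundingBoxFormat.CXCYWH,
)
```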
@@ -8,7 +8,7 @@ _FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[List[float]]
-def is_simple_tensor(inpt: Any) -> bool:
+def is_pure_tensor(inpt: Any) -> bool:
return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint)
......
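The one-line definition above is the whole contract: a pure tensor is any `torch.Tensor` that is not a `Datapoint` subclass. For example:

```python
import torch
from torchvision import datapoints
from torchvision.transforms.v2.utils import is_pure_tensor  # path used by the tests above

t = torch.rand(3, 32, 32)
img = datapoints.Image(t)

assert is_pure_tensor(t)        # plain tensor -> True
assert not is_pure_tensor(img)  # Image is a Datapoint subclass -> False
assert is_pure_tensor(img.as_subclass(torch.Tensor))  # stripping the subclass
```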
@@ -6,7 +6,7 @@ import PIL.Image
from torchvision import datapoints
from torchvision._utils import sequence_to_str
-from torchvision.transforms.v2.functional import get_dimensions, get_size, is_simple_tensor
+from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
def get_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
@@ -21,7 +21,7 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
chws = {
tuple(get_dimensions(inpt))
for inpt in flat_inputs
-if check_type(inpt, (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video))
+if check_type(inpt, (is_pure_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video))
}
if not chws:
raise TypeError("No image or video was found in the sample")
@@ -38,7 +38,7 @@ def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
if check_type(
inpt,
(
-is_simple_tensor,
+is_pure_tensor,
datapoints.Image,
PIL.Image.Image,
datapoints.Video,
......
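With this rename, `query_chw` and `query_size` continue to treat pure tensors as image candidates when extracting dimensions from a flattened sample. A small sketch:

```python
import torch
from torchvision.transforms.v2.utils import query_chw, query_size

flat_inputs = [torch.rand(3, 32, 48)]  # a pure tensor counts as an image here
assert query_chw(flat_inputs) == (3, 32, 48)
assert query_size(flat_inputs) == (32, 48)
```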