Unverified Commit d5f4cc38 authored by Nicolas Hug, committed by GitHub

Datapoint -> TVTensor; datapoint[s] -> tv_tensor[s] (#7894)

parent b9447fdd
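
The rename is mechanical: the tensor subclasses previously exposed as "datapoints" are exposed as "tv_tensors", and identifiers spelled datapoint/datapoints become tv_tensor/tv_tensors. A minimal sketch of user-side code after the rename, restricted to constructors that appear in the diff below (tv_tensors.Image, tv_tensors.BoundingBoxes with format and canvas_size); shapes and sizes here are illustrative only:

import torch
from torchvision import tv_tensors  # previously: from torchvision import datapoints

# Both classes are torch.Tensor subclasses; the constructor arguments mirror the tests below.
image = tv_tensors.Image(torch.randint(0, 256, size=(3, 32, 32), dtype=torch.uint8))
boxes = tv_tensors.BoundingBoxes(
    [[0, 0, 10, 10]],
    format=tv_tensors.BoundingBoxFormat.XYXY,
    canvas_size=(32, 32),
)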
@@ -7,11 +7,11 @@ import torch
from common_utils import assert_equal
from prototype_common_utils import make_label
-from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
-from torchvision.prototype import datapoints, transforms
+from torchvision.prototype import transforms, tv_tensors
from torchvision.transforms.v2._utils import check_type, is_pure_tensor
from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_pil_image
+from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
from transforms_v2_legacy_utils import (
DEFAULT_EXTRA_DIMS,
make_bounding_boxes,
@@ -51,7 +51,7 @@ class TestSimpleCopyPaste:
# images, batch size = 2
self.create_fake_image(mocker, Image),
# labels, bboxes, masks
-mocker.MagicMock(spec=datapoints.Label),
+mocker.MagicMock(spec=tv_tensors.Label),
mocker.MagicMock(spec=BoundingBoxes),
mocker.MagicMock(spec=Mask),
# labels, bboxes, masks
@@ -63,7 +63,7 @@ class TestSimpleCopyPaste:
transform._extract_image_targets(flat_sample)
@pytest.mark.parametrize("image_type", [Image, PIL.Image.Image, torch.Tensor])
-@pytest.mark.parametrize("label_type", [datapoints.Label, datapoints.OneHotLabel])
+@pytest.mark.parametrize("label_type", [tv_tensors.Label, tv_tensors.OneHotLabel])
def test__extract_image_targets(self, image_type, label_type, mocker):
transform = transforms.SimpleCopyPaste()
@@ -101,7 +101,7 @@ class TestSimpleCopyPaste:
assert isinstance(target[key], type_)
assert target[key] in flat_sample
-@pytest.mark.parametrize("label_type", [datapoints.Label, datapoints.OneHotLabel])
+@pytest.mark.parametrize("label_type", [tv_tensors.Label, tv_tensors.OneHotLabel])
def test__copy_paste(self, label_type):
image = 2 * torch.ones(3, 32, 32)
masks = torch.zeros(2, 32, 32)
@@ -111,7 +111,7 @@ class TestSimpleCopyPaste:
blending = True
resize_interpolation = InterpolationMode.BILINEAR
antialias = None
-if label_type == datapoints.OneHotLabel:
+if label_type == tv_tensors.OneHotLabel:
labels = torch.nn.functional.one_hot(labels, num_classes=5)
target = {
"boxes": BoundingBoxes(
@@ -126,7 +126,7 @@ class TestSimpleCopyPaste:
paste_masks[0, 13:19, 12:18] = 1
paste_masks[1, 15:19, 1:8] = 1
paste_labels = torch.tensor([3, 4])
-if label_type == datapoints.OneHotLabel:
+if label_type == tv_tensors.OneHotLabel:
paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5)
paste_target = {
"boxes": BoundingBoxes(
@@ -148,7 +148,7 @@ class TestSimpleCopyPaste:
torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"])
expected_labels = torch.tensor([1, 2, 3, 4])
-if label_type == datapoints.OneHotLabel:
+if label_type == tv_tensors.OneHotLabel:
expected_labels = torch.nn.functional.one_hot(expected_labels, num_classes=5)
torch.testing.assert_close(output_target["labels"], label_type(expected_labels))
@@ -258,10 +258,10 @@ class TestFixedSizeCrop:
class TestLabelToOneHot:
def test__transform(self):
categories = ["apple", "pear", "pineapple"]
-labels = datapoints.Label(torch.tensor([0, 1, 2, 1]), categories=categories)
+labels = tv_tensors.Label(torch.tensor([0, 1, 2, 1]), categories=categories)
transform = transforms.LabelToOneHot()
ohe_labels = transform(labels)
-assert isinstance(ohe_labels, datapoints.OneHotLabel)
+assert isinstance(ohe_labels, tv_tensors.OneHotLabel)
assert ohe_labels.shape == (4, 3)
assert ohe_labels.categories == labels.categories == categories
@@ -383,7 +383,7 @@ det_transforms = import_transforms_from_references("detection")
def test_fixed_sized_crop_against_detection_reference():
-def make_datapoints():
+def make_tv_tensors():
size = (600, 800)
num_objects = 22
@@ -405,19 +405,19 @@ def test_fixed_sized_crop_against_detection_reference():
yield (tensor_image, target)
-datapoint_image = make_image(size=size, color_space="RGB")
+tv_tensor_image = make_image(size=size, color_space="RGB")
target = {
"boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
"masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
}
-yield (datapoint_image, target)
+yield (tv_tensor_image, target)
t = transforms.FixedSizeCrop((1024, 1024), fill=0)
t_ref = det_transforms.FixedSizeCrop((1024, 1024), fill=0)
-for dp in make_datapoints():
+for dp in make_tv_tensors():
# We should use prototype transform first as reference transform performs inplace target update
torch.manual_seed(12)
output = t(dp)
...
@@ -13,7 +13,7 @@ import torchvision.transforms.v2 as transforms
from common_utils import assert_equal, cpu_and_cuda
from torch.utils._pytree import tree_flatten, tree_unflatten
-from torchvision import datapoints
+from torchvision import tv_tensors
from torchvision.ops.boxes import box_iou
from torchvision.transforms.functional import to_pil_image
from torchvision.transforms.v2 import functional as F
@@ -66,10 +66,10 @@ def auto_augment_adapter(transform, input, device):
adapted_input = {}
image_or_video_found = False
for key, value in input.items():
-if isinstance(value, (datapoints.BoundingBoxes, datapoints.Mask)):
+if isinstance(value, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
# AA transforms don't support bounding boxes or masks
continue
-elif check_type(value, (datapoints.Image, datapoints.Video, is_pure_tensor, PIL.Image.Image)):
+elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor, PIL.Image.Image)):
if image_or_video_found:
# AA transforms only support a single image or video
continue
@@ -99,7 +99,7 @@ def normalize_adapter(transform, input, device):
if isinstance(value, PIL.Image.Image):
# normalize doesn't support PIL images
continue
-elif check_type(value, (datapoints.Image, datapoints.Video, is_pure_tensor)):
+elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor)):
# normalize doesn't support integer images
value = F.to_dtype(value, torch.float32, scale=True)
adapted_input[key] = value
@@ -142,7 +142,7 @@ class TestSmoke:
(transforms.Resize([16, 16], antialias=True), None),
(transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2), antialias=True), None),
(transforms.ClampBoundingBoxes(), None),
-(transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
+(transforms.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.CXCYWH), None),
(transforms.ConvertImageDtype(), None),
(transforms.GaussianBlur(kernel_size=3), None),
(
@@ -178,19 +178,19 @@ class TestSmoke:
canvas_size = F.get_size(image_or_video)
input = dict(
image_or_video=image_or_video,
-image_datapoint=make_image(size=canvas_size),
+image_tv_tensor=make_image(size=canvas_size),
-video_datapoint=make_video(size=canvas_size),
+video_tv_tensor=make_video(size=canvas_size),
image_pil=next(make_pil_images(sizes=[canvas_size], color_spaces=["RGB"])),
bounding_boxes_xyxy=make_bounding_boxes(
-format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(3,)
+format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(3,)
),
bounding_boxes_xywh=make_bounding_boxes(
-format=datapoints.BoundingBoxFormat.XYWH, canvas_size=canvas_size, batch_dims=(4,)
+format=tv_tensors.BoundingBoxFormat.XYWH, canvas_size=canvas_size, batch_dims=(4,)
),
bounding_boxes_cxcywh=make_bounding_boxes(
-format=datapoints.BoundingBoxFormat.CXCYWH, canvas_size=canvas_size, batch_dims=(5,)
+format=tv_tensors.BoundingBoxFormat.CXCYWH, canvas_size=canvas_size, batch_dims=(5,)
),
-bounding_boxes_degenerate_xyxy=datapoints.BoundingBoxes(
+bounding_boxes_degenerate_xyxy=tv_tensors.BoundingBoxes(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
@@ -199,10 +199,10 @@ class TestSmoke:
[0, 2, 1, 1], # x1 < x2, y1 > y2
[2, 2, 1, 1], # x1 > x2, y1 > y2
],
-format=datapoints.BoundingBoxFormat.XYXY,
+format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=canvas_size,
),
-bounding_boxes_degenerate_xywh=datapoints.BoundingBoxes(
+bounding_boxes_degenerate_xywh=tv_tensors.BoundingBoxes(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
@@ -211,10 +211,10 @@ class TestSmoke:
[0, 0, -1, 1], # negative width
[0, 0, -1, -1], # negative height and width
],
-format=datapoints.BoundingBoxFormat.XYWH,
+format=tv_tensors.BoundingBoxFormat.XYWH,
canvas_size=canvas_size,
),
-bounding_boxes_degenerate_cxcywh=datapoints.BoundingBoxes(
+bounding_boxes_degenerate_cxcywh=tv_tensors.BoundingBoxes(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
@@ -223,7 +223,7 @@ class TestSmoke:
[0, 0, -1, 1], # negative width
[0, 0, -1, -1], # negative height and width
],
-format=datapoints.BoundingBoxFormat.CXCYWH,
+format=tv_tensors.BoundingBoxFormat.CXCYWH,
canvas_size=canvas_size,
),
detection_mask=make_detection_mask(size=canvas_size),
@@ -262,7 +262,7 @@ class TestSmoke:
else:
assert output_item is input_item
-if isinstance(input_item, datapoints.BoundingBoxes) and not isinstance(
+if isinstance(input_item, tv_tensors.BoundingBoxes) and not isinstance(
transform, transforms.ConvertBoundingBoxFormat
):
assert output_item.format == input_item.format
@@ -270,9 +270,9 @@ class TestSmoke:
# Enforce that the transform does not turn a degenerate box marked by RandomIoUCrop (or any other future
# transform that does this), back into a valid one.
# TODO: we should test that against all degenerate boxes above
-for format in list(datapoints.BoundingBoxFormat):
+for format in list(tv_tensors.BoundingBoxFormat):
sample = dict(
-boxes=datapoints.BoundingBoxes([[0, 0, 0, 0]], format=format, canvas_size=(224, 244)),
+boxes=tv_tensors.BoundingBoxes([[0, 0, 0, 0]], format=format, canvas_size=(224, 244)),
labels=torch.tensor([3]),
)
assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)
@@ -652,7 +652,7 @@ class TestRandomErasing:
class TestTransform:
@pytest.mark.parametrize(
"inpt_type",
-[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
+[torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
)
def test_check_transformed_types(self, inpt_type, mocker):
# This test ensures that we correctly handle which types to transform and which to bypass
@@ -670,7 +670,7 @@ class TestTransform:
class TestToImage:
@pytest.mark.parametrize(
"inpt_type",
-[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
+[torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch(
@@ -681,7 +681,7 @@ class TestToImage:
inpt = mocker.MagicMock(spec=inpt_type)
transform = transforms.ToImage()
transform(inpt)
-if inpt_type in (datapoints.BoundingBoxes, datapoints.Image, str, int):
+if inpt_type in (tv_tensors.BoundingBoxes, tv_tensors.Image, str, int):
assert fn.call_count == 0
else:
fn.assert_called_once_with(inpt)
@@ -690,7 +690,7 @@ class TestToImage:
class TestToPILImage:
@pytest.mark.parametrize(
"inpt_type",
-[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
+[torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch("torchvision.transforms.v2.functional.to_pil_image")
@@ -698,7 +698,7 @@ class TestToPILImage:
inpt = mocker.MagicMock(spec=inpt_type)
transform = transforms.ToPILImage()
transform(inpt)
-if inpt_type in (PIL.Image.Image, datapoints.BoundingBoxes, str, int):
+if inpt_type in (PIL.Image.Image, tv_tensors.BoundingBoxes, str, int):
assert fn.call_count == 0
else:
fn.assert_called_once_with(inpt, mode=transform.mode)
@@ -707,7 +707,7 @@ class TestToPILImage:
class TestToTensor:
@pytest.mark.parametrize(
"inpt_type",
-[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
+[torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch("torchvision.transforms.functional.to_tensor")
@@ -716,7 +716,7 @@ class TestToTensor:
with pytest.warns(UserWarning, match="deprecated and will be removed"):
transform = transforms.ToTensor()
transform(inpt)
-if inpt_type in (datapoints.Image, torch.Tensor, datapoints.BoundingBoxes, str, int):
+if inpt_type in (tv_tensors.Image, torch.Tensor, tv_tensors.BoundingBoxes, str, int):
assert fn.call_count == 0
else:
fn.assert_called_once_with(inpt)
@@ -757,7 +757,7 @@ class TestRandomIoUCrop:
def test__get_params(self, device, options):
orig_h, orig_w = size = (24, 32)
image = make_image(size)
-bboxes = datapoints.BoundingBoxes(
+bboxes = tv_tensors.BoundingBoxes(
torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]),
format="XYXY",
canvas_size=size,
@@ -792,8 +792,8 @@ class TestRandomIoUCrop:
def test__transform_empty_params(self, mocker):
transform = transforms.RandomIoUCrop(sampler_options=[2.0])
-image = datapoints.Image(torch.rand(1, 3, 4, 4))
+image = tv_tensors.Image(torch.rand(1, 3, 4, 4))
-bboxes = datapoints.BoundingBoxes(torch.tensor([[1, 1, 2, 2]]), format="XYXY", canvas_size=(4, 4))
+bboxes = tv_tensors.BoundingBoxes(torch.tensor([[1, 1, 2, 2]]), format="XYXY", canvas_size=(4, 4))
label = torch.tensor([1])
sample = [image, bboxes, label]
# Let's mock transform._get_params to control the output:
@@ -827,11 +827,11 @@ class TestRandomIoUCrop:
# check number of bboxes vs number of labels:
output_bboxes = output[1]
-assert isinstance(output_bboxes, datapoints.BoundingBoxes)
+assert isinstance(output_bboxes, tv_tensors.BoundingBoxes)
assert (output_bboxes[~is_within_crop_area] == 0).all()
output_masks = output[2]
-assert isinstance(output_masks, datapoints.Mask)
+assert isinstance(output_masks, tv_tensors.Mask)
class TestScaleJitter:
@@ -899,7 +899,7 @@ class TestLinearTransformation:
[
122 * torch.ones(1, 3, 8, 8),
122.0 * torch.ones(1, 3, 8, 8),
-datapoints.Image(122 * torch.ones(1, 3, 8, 8)),
+tv_tensors.Image(122 * torch.ones(1, 3, 8, 8)),
PIL.Image.new("RGB", (8, 8), (122, 122, 122)),
],
)
@@ -941,7 +941,7 @@ class TestUniformTemporalSubsample:
[
torch.zeros(10, 3, 8, 8),
torch.zeros(1, 10, 3, 8, 8),
-datapoints.Video(torch.zeros(1, 10, 3, 8, 8)),
+tv_tensors.Video(torch.zeros(1, 10, 3, 8, 8)),
],
)
def test__transform(self, inpt):
@@ -971,12 +971,12 @@ def test_antialias_warning():
transforms.RandomResize(10, 20)(tensor_img)
with pytest.warns(UserWarning, match=match):
-F.resized_crop(datapoints.Image(tensor_img), 0, 0, 10, 10, (20, 20))
+F.resized_crop(tv_tensors.Image(tensor_img), 0, 0, 10, 10, (20, 20))
with pytest.warns(UserWarning, match=match):
-F.resize(datapoints.Video(tensor_video), (20, 20))
+F.resize(tv_tensors.Video(tensor_video), (20, 20))
with pytest.warns(UserWarning, match=match):
-F.resized_crop(datapoints.Video(tensor_video), 0, 0, 10, 10, (20, 20))
+F.resized_crop(tv_tensors.Video(tensor_video), 0, 0, 10, 10, (20, 20))
with warnings.catch_warnings():
warnings.simplefilter("error")
@@ -990,17 +990,17 @@ def test_antialias_warning():
transforms.RandomShortestSize((20, 20), antialias=True)(tensor_img)
transforms.RandomResize(10, 20, antialias=True)(tensor_img)
-F.resized_crop(datapoints.Image(tensor_img), 0, 0, 10, 10, (20, 20), antialias=True)
+F.resized_crop(tv_tensors.Image(tensor_img), 0, 0, 10, 10, (20, 20), antialias=True)
-F.resized_crop(datapoints.Video(tensor_video), 0, 0, 10, 10, (20, 20), antialias=True)
+F.resized_crop(tv_tensors.Video(tensor_video), 0, 0, 10, 10, (20, 20), antialias=True)
-@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
+@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
@pytest.mark.parametrize("label_type", (torch.Tensor, int))
@pytest.mark.parametrize("dataset_return_type", (dict, tuple))
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImage))
def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
-image = datapoints.Image(torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8))
+image = tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8))
if image_type is PIL.Image:
image = to_pil_image(image[0])
elif image_type is torch.Tensor:
@@ -1056,7 +1056,7 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
assert out_label == label
-@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
+@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
@pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImage))
@pytest.mark.parametrize("sanitize", (True, False))
@@ -1082,7 +1082,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
# leaving FixedSizeCrop in prototype for now, and it expects Label
# classes which we won't release yet.
# transforms.FixedSizeCrop(
-# size=(1024, 1024), fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})
+# size=(1024, 1024), fill=defaultdict(lambda: (123.0, 117.0, 104.0), {tv_tensors.Mask: 0})
# ),
transforms.RandomCrop((1024, 1024), pad_if_needed=True),
transforms.RandomHorizontalFlip(p=1),
@@ -1101,7 +1101,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
elif data_augmentation == "ssd":
t = [
transforms.RandomPhotometricDistort(p=1),
-transforms.RandomZoomOut(fill={"others": (123.0, 117.0, 104.0), datapoints.Mask: 0}, p=1),
+transforms.RandomZoomOut(fill={"others": (123.0, 117.0, 104.0), tv_tensors.Mask: 0}, p=1),
transforms.RandomIoUCrop(),
transforms.RandomHorizontalFlip(p=1),
to_tensor,
@@ -1121,7 +1121,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
num_boxes = 5
H = W = 250
-image = datapoints.Image(torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8))
+image = tv_tensors.Image(torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8))
if image_type is PIL.Image:
image = to_pil_image(image[0])
elif image_type is torch.Tensor:
@@ -1133,9 +1133,9 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4))
boxes[:, 2:] += boxes[:, :2]
boxes = boxes.clamp(min=0, max=min(H, W))
-boxes = datapoints.BoundingBoxes(boxes, format="XYXY", canvas_size=(H, W))
+boxes = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=(H, W))
-masks = datapoints.Mask(torch.randint(0, 2, size=(num_boxes, H, W), dtype=torch.uint8))
+masks = tv_tensors.Mask(torch.randint(0, 2, size=(num_boxes, H, W), dtype=torch.uint8))
sample = {
"image": image,
@@ -1146,10 +1146,10 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
out = t(sample)
-if isinstance(to_tensor, transforms.ToTensor) and image_type is not datapoints.Image:
+if isinstance(to_tensor, transforms.ToTensor) and image_type is not tv_tensors.Image:
assert is_pure_tensor(out["image"])
else:
-assert isinstance(out["image"], datapoints.Image)
+assert isinstance(out["image"], tv_tensors.Image)
assert isinstance(out["label"], type(sample["label"]))
num_boxes_expected = {
@@ -1204,13 +1204,13 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
boxes = torch.tensor(boxes)
labels = torch.arange(boxes.shape[0])
-boxes = datapoints.BoundingBoxes(
+boxes = tv_tensors.BoundingBoxes(
boxes,
-format=datapoints.BoundingBoxFormat.XYXY,
+format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=(H, W),
)
-masks = datapoints.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
+masks = tv_tensors.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
whatever = torch.rand(10)
input_img = torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8)
sample = {
@@ -1244,8 +1244,8 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
assert out_image is input_img
assert out_whatever is whatever
-assert isinstance(out_boxes, datapoints.BoundingBoxes)
+assert isinstance(out_boxes, tv_tensors.BoundingBoxes)
-assert isinstance(out_masks, datapoints.Mask)
+assert isinstance(out_masks, tv_tensors.Mask)
if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None):
assert out_labels is labels
@@ -1266,15 +1266,15 @@ def test_sanitize_bounding_boxes_no_label():
transforms.SanitizeBoundingBoxes()(img, boxes)
out_img, out_boxes = transforms.SanitizeBoundingBoxes(labels_getter=None)(img, boxes)
-assert isinstance(out_img, datapoints.Image)
+assert isinstance(out_img, tv_tensors.Image)
-assert isinstance(out_boxes, datapoints.BoundingBoxes)
+assert isinstance(out_boxes, tv_tensors.BoundingBoxes)
def test_sanitize_bounding_boxes_errors():
-good_bbox = datapoints.BoundingBoxes(
+good_bbox = tv_tensors.BoundingBoxes(
[[0, 0, 10, 10]],
-format=datapoints.BoundingBoxFormat.XYXY,
+format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=(20, 20),
)
...
@@ -13,7 +13,7 @@ import torch
import torchvision.transforms.v2 as v2_transforms
from common_utils import assert_close, assert_equal, set_rng_seed
from torch import nn
-from torchvision import datapoints, transforms as legacy_transforms
+from torchvision import transforms as legacy_transforms, tv_tensors
from torchvision._utils import sequence_to_str
from torchvision.transforms import functional as legacy_F
@@ -478,15 +478,15 @@ def check_call_consistency(
output_prototype_image = prototype_transform(image)
except Exception as exc:
raise AssertionError(
-f"Transforming a image datapoint with shape {image_repr} failed in the prototype transform with "
+f"Transforming a image tv_tensor with shape {image_repr} failed in the prototype transform with "
f"the error above. This means there is a consistency bug either in `_get_params` or in the "
-f"`datapoints.Image` path in `_transform`."
+f"`tv_tensors.Image` path in `_transform`."
) from exc
assert_close(
output_prototype_image,
output_prototype_tensor,
-msg=lambda msg: f"Output for datapoint and tensor images is not equal: \n\n{msg}",
+msg=lambda msg: f"Output for tv_tensor and tensor images is not equal: \n\n{msg}",
**closeness_kwargs,
)
@@ -747,7 +747,7 @@ class TestAATransforms:
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
-datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize(
@@ -812,7 +812,7 @@ class TestAATransforms:
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
-datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize(
@@ -887,7 +887,7 @@ class TestAATransforms:
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
-datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize(
@@ -964,7 +964,7 @@ class TestAATransforms:
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
-datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
+tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize(
@@ -1030,7 +1030,7 @@ det_transforms = import_transforms_from_references("detection")
class TestRefDetTransforms:
-def make_datapoints(self, with_mask=True):
+def make_tv_tensors(self, with_mask=True):
size = (600, 800)
num_objects = 22
@@ -1057,7 +1057,7 @@ class TestRefDetTransforms:
yield (tensor_image, target)
-datapoint_image = make_image(size=size, color_space="RGB", dtype=torch.float32)
+tv_tensor_image = make_image(size=size, color_space="RGB", dtype=torch.float32)
target = {
"boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -1065,7 +1065,7 @@ class TestRefDetTransforms:
if with_mask:
target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
-yield (datapoint_image, target)
+yield (tv_tensor_image, target)
@pytest.mark.parametrize(
"t_ref, t, data_kwargs",
@@ -1095,7 +1095,7 @@ class TestRefDetTransforms:
],
)
def test_transform(self, t_ref, t, data_kwargs):
-for dp in self.make_datapoints(**data_kwargs):
+for dp in self.make_tv_tensors(**data_kwargs):
# We should use prototype transform first as reference transform performs inplace target update
torch.manual_seed(12)
@@ -1135,7 +1135,7 @@ class PadIfSmaller(v2_transforms.Transform):
class TestRefSegTransforms:
-def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):
+def make_tv_tensors(self, supports_pil=True, image_dtype=torch.uint8):
size = (256, 460)
num_categories = 21
@@ -1145,13 +1145,13 @@ class TestRefSegTransforms:
conv_fns.extend([torch.Tensor, lambda x: x])
for conv_fn in conv_fns:
-datapoint_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
+tv_tensor_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
-datapoint_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)
+tv_tensor_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)
-dp = (conv_fn(datapoint_image), datapoint_mask)
+dp = (conv_fn(tv_tensor_image), tv_tensor_mask)
dp_ref = (
-to_pil_image(datapoint_image) if supports_pil else datapoint_image.as_subclass(torch.Tensor),
+to_pil_image(tv_tensor_image) if supports_pil else tv_tensor_image.as_subclass(torch.Tensor),
-to_pil_image(datapoint_mask),
+to_pil_image(tv_tensor_mask),
)
yield dp, dp_ref
@@ -1161,7 +1161,7 @@ class TestRefSegTransforms:
random.seed(seed)
def check(self, t, t_ref, data_kwargs=None):
-for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()):
+for dp, dp_ref in self.make_tv_tensors(**data_kwargs or dict()):
self.set_seed()
actual = actual_image, actual_mask = t(dp)
@@ -1192,7 +1192,7 @@ class TestRefSegTransforms:
seg_transforms.RandomCrop(size=480),
v2_transforms.Compose(
[
-PadIfSmaller(size=480, fill={datapoints.Mask: 255, "others": 0}),
+PadIfSmaller(size=480, fill={tv_tensors.Mask: 255, "others": 0}),
v2_transforms.RandomCrop(size=480),
]
),
...
@@ -10,7 +10,7 @@ import torch
from common_utils import assert_close, cache, cpu_and_cuda, needs_cuda, set_rng_seed
from torch.utils._pytree import tree_map
-from torchvision import datapoints
+from torchvision import tv_tensors
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2._utils import is_pure_tensor
@@ -164,22 +164,22 @@ class TestKernels:
def test_batched_vs_single(self, test_id, info, args_kwargs, device):
(batched_input, *other_args), kwargs = args_kwargs.load(device)
-datapoint_type = datapoints.Image if is_pure_tensor(batched_input) else type(batched_input)
+tv_tensor_type = tv_tensors.Image if is_pure_tensor(batched_input) else type(batched_input)
# This dictionary contains the number of rightmost dimensions that contain the actual data.
# Everything to the left is considered a batch dimension.
data_dims = {
-datapoints.Image: 3,
+tv_tensors.Image: 3,
-datapoints.BoundingBoxes: 1,
+tv_tensors.BoundingBoxes: 1,
# `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection masks
# it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both a grouped under one
# type all kernels should also work without differentiating between the two. Thus, we go with 2 here as
# common ground.
-datapoints.Mask: 2,
+tv_tensors.Mask: 2,
-datapoints.Video: 4,
+tv_tensors.Video: 4,
-}.get(datapoint_type)
+}.get(tv_tensor_type)
if data_dims is None:
raise pytest.UsageError(
-f"The number of data dimensions cannot be determined for input of type {datapoint_type.__name__}."
+f"The number of data dimensions cannot be determined for input of type {tv_tensor_type.__name__}."
) from None
elif batched_input.ndim <= data_dims:
pytest.skip("Input is not batched.")
@@ -305,8 +305,8 @@ def spy_on(mocker):
class TestDispatchers:
image_sample_inputs = make_info_args_kwargs_parametrization(
-[info for info in DISPATCHER_INFOS if datapoints.Image in info.kernels],
+[info for info in DISPATCHER_INFOS if tv_tensors.Image in info.kernels],
-args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
+args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.Image),
)
@make_info_args_kwargs_parametrization(
@@ -328,8 +328,8 @@ class TestDispatchers:
def test_scripted_smoke(self, info, args_kwargs, device):
dispatcher = script(info.dispatcher)
-(image_datapoint, *other_args), kwargs = args_kwargs.load(device)
+(image_tv_tensor, *other_args), kwargs = args_kwargs.load(device)
-image_pure_tensor = torch.Tensor(image_datapoint)
+image_pure_tensor = torch.Tensor(image_tv_tensor)
dispatcher(image_pure_tensor, *other_args, **kwargs)
@@ -355,25 +355,25 @@ class TestDispatchers:
@image_sample_inputs
def test_pure_tensor_output_type(self, info, args_kwargs):
-(image_datapoint, *other_args), kwargs = args_kwargs.load()
+(image_tv_tensor, *other_args), kwargs = args_kwargs.load()
-image_pure_tensor = image_datapoint.as_subclass(torch.Tensor)
+image_pure_tensor = image_tv_tensor.as_subclass(torch.Tensor)
output = info.dispatcher(image_pure_tensor, *other_args, **kwargs)
-# We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
+# We cannot use `isinstance` here since all tv_tensors are instances of `torch.Tensor` as well
assert type(output) is torch.Tensor
@make_info_args_kwargs_parametrization(
[info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
-args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
+args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.Image),
)
def test_pil_output_type(self, info, args_kwargs):
-(image_datapoint, *other_args), kwargs = args_kwargs.load()
+(image_tv_tensor, *other_args), kwargs = args_kwargs.load()
-if image_datapoint.ndim > 3:
+if image_tv_tensor.ndim > 3:
pytest.skip("Input is batched")
-image_pil = F.to_pil_image(image_datapoint)
+image_pil = F.to_pil_image(image_tv_tensor)
output = info.dispatcher(image_pil, *other_args, **kwargs)
@@ -383,38 +383,38 @@ class TestDispatchers:
DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(),
)
-def test_datapoint_output_type(self, info, args_kwargs):
+def test_tv_tensor_output_type(self, info, args_kwargs):
-(datapoint, *other_args), kwargs = args_kwargs.load()
+(tv_tensor, *other_args), kwargs = args_kwargs.load()
-output = info.dispatcher(datapoint, *other_args, **kwargs)
+output = info.dispatcher(tv_tensor, *other_args, **kwargs)
-assert isinstance(output, type(datapoint))
+assert isinstance(output, type(tv_tensor))
-if isinstance(datapoint, datapoints.BoundingBoxes) and info.dispatcher is not F.convert_bounding_box_format:
+if isinstance(tv_tensor, tv_tensors.BoundingBoxes) and info.dispatcher is not F.convert_bounding_box_format:
-assert output.format == datapoint.format
+assert output.format == tv_tensor.format
@pytest.mark.parametrize(
-("dispatcher_info", "datapoint_type", "kernel_info"),
+("dispatcher_info", "tv_tensor_type", "kernel_info"),
[
pytest.param(
-dispatcher_info, datapoint_type, kernel_info, id=f"{dispatcher_info.id}-{datapoint_type.__name__}"
+dispatcher_info, tv_tensor_type, kernel_info, id=f"{dispatcher_info.id}-{tv_tensor_type.__name__}"
)
for dispatcher_info in DISPATCHER_INFOS
-for datapoint_type, kernel_info in dispatcher_info.kernel_infos.items()
+for tv_tensor_type, kernel_info in dispatcher_info.kernel_infos.items()
],
)
-def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoint_type, kernel_info):
+def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, tv_tensor_type, kernel_info):
dispatcher_signature = inspect.signature(dispatcher_info.dispatcher)
dispatcher_params = list(dispatcher_signature.parameters.values())[1:]
kernel_signature = inspect.signature(kernel_info.kernel)
kernel_params = list(kernel_signature.parameters.values())[1:]
-# We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
+# We filter out metadata that is implicitly passed to the dispatcher through the input tv_tensor, but has to be
# explicitly passed to the kernel.
input_type = {v: k for k, v in dispatcher_info.kernels.items()}.get(kernel_info.kernel)
explicit_metadata = {
-datapoints.BoundingBoxes: {"format", "canvas_size"},
+tv_tensors.BoundingBoxes: {"format", "canvas_size"},
}
kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]
@@ -445,9 +445,9 @@ class TestDispatchers:
[
info
for info in DISPATCHER_INFOS
-if datapoints.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_bounding_box_format
+if tv_tensors.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_bounding_box_format
],
-args_kwargs_fn=lambda info: info.sample_inputs(datapoints.BoundingBoxes),
+args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.BoundingBoxes),
)
def test_bounding_boxes_format_consistency(self, info, args_kwargs):
(bounding_boxes, *other_args), kwargs = args_kwargs.load()
@@ -497,7 +497,7 @@ class TestClampBoundingBoxes:
"metadata",
[
dict(),
-dict(format=datapoints.BoundingBoxFormat.XYXY),
+dict(format=tv_tensors.BoundingBoxFormat.XYXY),
dict(canvas_size=(1, 1)),
],
)
@@ -510,16 +510,16 @@ class TestClampBoundingBoxes:
@pytest.mark.parametrize(
"metadata",
[
-dict(format=datapoints.BoundingBoxFormat.XYXY),
+dict(format=tv_tensors.BoundingBoxFormat.XYXY),
dict(canvas_size=(1, 1)),
-dict(format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(1, 1)),
+dict(format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=(1, 1)),
],
)
-def test_datapoint_explicit_metadata(self, metadata):
+def test_tv_tensor_explicit_metadata(self, metadata):
-datapoint = next(make_multiple_bounding_boxes())
+tv_tensor = next(make_multiple_bounding_boxes())
with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` must not be passed")):
-F.clamp_bounding_boxes(datapoint, **metadata)
+F.clamp_bounding_boxes(tv_tensor, **metadata)
class TestConvertFormatBoundingBoxes:
@@ -527,7 +527,7 @@ class TestConvertFormatBoundingBoxes:
("inpt", "old_format"),
[
(next(make_multiple_bounding_boxes()), None),
-(next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor), datapoints.BoundingBoxFormat.XYXY),
+(next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor), tv_tensors.BoundingBoxFormat.XYXY),
],
)
def test_missing_new_format(self, inpt, old_format):
@@ -538,14 +538,14 @@ class TestConvertFormatBoundingBoxes:
pure_tensor = next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor)
with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
-F.convert_bounding_box_format(pure_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)
+F.convert_bounding_box_format(pure_tensor, new_format=tv_tensors.BoundingBoxFormat.CXCYWH)
-def test_datapoint_explicit_metadata(self):
+def test_tv_tensor_explicit_metadata(self):
-datapoint = next(make_multiple_bounding_boxes())
+tv_tensor = next(make_multiple_bounding_boxes())
with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")):
F.convert_bounding_box_format(
-datapoint, old_format=datapoint.format, new_format=datapoints.BoundingBoxFormat.CXCYWH
+tv_tensor, old_format=tv_tensor.format, new_format=tv_tensors.BoundingBoxFormat.CXCYWH
)
@@ -579,7 +579,7 @@ def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
"format",
-[datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
+[tv_tensors.BoundingBoxFormat.XYXY, tv_tensors.BoundingBoxFormat.XYWH, tv_tensors.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
"top, left, height, width, expected_bboxes",
@@ -602,7 +602,7 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
# out_box = denormalize_bbox(n_out_box, height, width)
# expected_bboxes.append(out_box)
-format = datapoints.BoundingBoxFormat.XYXY
+format = tv_tensors.BoundingBoxFormat.XYXY
canvas_size = (64, 76)
in_boxes = [
[10.0, 15.0, 25.0, 35.0],
@@ -610,11 +610,11 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
[45.0, 46.0, 56.0, 62.0],
]
in_boxes = torch.tensor(in_boxes, device=device)
-if format != datapoints.BoundingBoxFormat.XYXY:
+if format != tv_tensors.BoundingBoxFormat.XYXY:
-in_boxes = convert_bounding_box_format(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
+in_boxes = convert_bounding_box_format(in_boxes, tv_tensors.BoundingBoxFormat.XYXY, format)
expected_bboxes = clamp_bounding_boxes(
-datapoints.BoundingBoxes(expected_bboxes, format="XYXY", canvas_size=canvas_size)
+tv_tensors.BoundingBoxes(expected_bboxes, format="XYXY", canvas_size=canvas_size)
).tolist()
output_boxes, output_canvas_size = F.crop_bounding_boxes(
@@ -626,8 +626,8 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
canvas_size[1],
)
-if format != datapoints.BoundingBoxFormat.XYXY:
+if format != tv_tensors.BoundingBoxFormat.XYXY:
-output_boxes = convert_bounding_box_format(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
+output_boxes = convert_bounding_box_format(output_boxes, format, tv_tensors.BoundingBoxFormat.XYXY)
torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
torch.testing.assert_close(output_canvas_size, canvas_size)
@@ -648,7 +648,7 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
"format",
-[datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
+[tv_tensors.BoundingBoxFormat.XYXY, tv_tensors.BoundingBoxFormat.XYWH, tv_tensors.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
"top, left, height, width, size",
@@ -666,7 +666,7 @@ def test_correctness_resized_crop_bounding_boxes(device, format, top, left, heig
bbox[3] = (bbox[3] - top_) * size_[0] / height_
return bbox
-format = datapoints.BoundingBoxFormat.XYXY
+format = tv_tensors.BoundingBoxFormat.XYXY
canvas_size = (100, 100)
in_boxes = [
[10.0, 10.0, 20.0, 20.0],
@@ -677,16 +677,16 @@ def test_correctness_resized_crop_bounding_boxes(device, format, top, left, heig
expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size))
expected_bboxes = torch.tensor(expected_bboxes, device=device)
-in_boxes = datapoints.BoundingBoxes(
+in_boxes = tv_tensors.BoundingBoxes(
-in_boxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size, device=device
+in_boxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size, device=device
)
-if format != datapoints.BoundingBoxFormat.XYXY:
+if format != tv_tensors.BoundingBoxFormat.XYXY:
-in_boxes = convert_bounding_box_format(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
+in_boxes = convert_bounding_box_format(in_boxes, tv_tensors.BoundingBoxFormat.XYXY, format)
output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size)
-if format != datapoints.BoundingBoxFormat.XYXY:
+if format != tv_tensors.BoundingBoxFormat.XYXY:
output_boxes = convert_bounding_box_format(output_boxes, format, datapoints.BoundingBoxFormat.XYXY) output_boxes = convert_bounding_box_format(output_boxes, format, tv_tensors.BoundingBoxFormat.XYXY)
torch.testing.assert_close(output_boxes, expected_bboxes) torch.testing.assert_close(output_boxes, expected_bboxes)
torch.testing.assert_close(output_canvas_size, size) torch.testing.assert_close(output_canvas_size, size)
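For reference, a usage sketch of the kernel exercised above, based only on the call signature and return value visible in this test (the box values and crop parameters are illustrative):

import torch
from torchvision import tv_tensors
import torchvision.transforms.v2.functional as F

boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[10.0, 10.0, 20.0, 20.0]]),
    format=tv_tensors.BoundingBoxFormat.XYXY,
    canvas_size=(100, 100),
)

# Crop the (top=10, left=5, height=60, width=70) region and resize it to (50, 50);
# the kernel returns the transformed boxes plus the new canvas size, (50, 50) here.
out_boxes, out_canvas_size = F.resized_crop_bounding_boxes(
    boxes, tv_tensors.BoundingBoxFormat.XYXY, 10, 5, 60, 70, (50, 50)
)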
...@@ -713,14 +713,14 @@ def test_correctness_pad_bounding_boxes(device, padding): ...@@ -713,14 +713,14 @@ def test_correctness_pad_bounding_boxes(device, padding):
dtype = bbox.dtype dtype = bbox.dtype
bbox = ( bbox = (
bbox.clone() bbox.clone()
if format == datapoints.BoundingBoxFormat.XYXY if format == tv_tensors.BoundingBoxFormat.XYXY
else convert_bounding_box_format(bbox, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY) else convert_bounding_box_format(bbox, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
) )
bbox[0::2] += pad_left bbox[0::2] += pad_left
bbox[1::2] += pad_up bbox[1::2] += pad_up
bbox = convert_bounding_box_format(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format) bbox = convert_bounding_box_format(bbox, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format)
if bbox.dtype != dtype: if bbox.dtype != dtype:
# Temporary cast to original dtype # Temporary cast to original dtype
# e.g. float32 -> int # e.g. float32 -> int
...@@ -785,7 +785,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints): ...@@ -785,7 +785,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
] ]
) )
bbox_xyxy = convert_bounding_box_format(bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY) bbox_xyxy = convert_bounding_box_format(bbox, old_format=format_, new_format=tv_tensors.BoundingBoxFormat.XYXY)
points = np.array( points = np.array(
[ [
[bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0], [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
...@@ -807,7 +807,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints): ...@@ -807,7 +807,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
) )
out_bbox = torch.from_numpy(out_bbox) out_bbox = torch.from_numpy(out_bbox)
out_bbox = convert_bounding_box_format( out_bbox = convert_bounding_box_format(
out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_ out_bbox, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format_
) )
return clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_).to(bbox) return clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_).to(bbox)
...@@ -846,7 +846,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints): ...@@ -846,7 +846,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
def test_correctness_center_crop_bounding_boxes(device, output_size): def test_correctness_center_crop_bounding_boxes(device, output_size):
def _compute_expected_bbox(bbox, format_, canvas_size_, output_size_): def _compute_expected_bbox(bbox, format_, canvas_size_, output_size_):
dtype = bbox.dtype dtype = bbox.dtype
bbox = convert_bounding_box_format(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH) bbox = convert_bounding_box_format(bbox.float(), format_, tv_tensors.BoundingBoxFormat.XYWH)
if len(output_size_) == 1: if len(output_size_) == 1:
output_size_.append(output_size_[-1]) output_size_.append(output_size_[-1])
...@@ -860,7 +860,7 @@ def test_correctness_center_crop_bounding_boxes(device, output_size): ...@@ -860,7 +860,7 @@ def test_correctness_center_crop_bounding_boxes(device, output_size):
bbox[3].item(), bbox[3].item(),
] ]
out_bbox = torch.tensor(out_bbox) out_bbox = torch.tensor(out_bbox)
out_bbox = convert_bounding_box_format(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_) out_bbox = convert_bounding_box_format(out_bbox, tv_tensors.BoundingBoxFormat.XYWH, format_)
out_bbox = clamp_bounding_boxes(out_bbox, format=format_, canvas_size=output_size) out_bbox = clamp_bounding_boxes(out_bbox, format=format_, canvas_size=output_size)
return out_bbox.to(dtype=dtype, device=bbox.device) return out_bbox.to(dtype=dtype, device=bbox.device)
...@@ -958,7 +958,7 @@ def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize, ...@@ -958,7 +958,7 @@ def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize,
torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor) torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)
) )
image = datapoints.Image(tensor) image = tv_tensors.Image(tensor)
out = fn(image, kernel_size=ksize, sigma=sigma) out = fn(image, kernel_size=ksize, sigma=sigma)
torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}") torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")
......
...@@ -6,46 +6,46 @@ import torch ...@@ -6,46 +6,46 @@ import torch
import torchvision.transforms.v2._utils import torchvision.transforms.v2._utils
from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_mask, make_image from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_mask, make_image
from torchvision import datapoints from torchvision import tv_tensors
from torchvision.transforms.v2._utils import has_all, has_any from torchvision.transforms.v2._utils import has_all, has_any
from torchvision.transforms.v2.functional import to_pil_image from torchvision.transforms.v2.functional import to_pil_image
IMAGE = make_image(DEFAULT_SIZE, color_space="RGB") IMAGE = make_image(DEFAULT_SIZE, color_space="RGB")
BOUNDING_BOX = make_bounding_boxes(DEFAULT_SIZE, format=datapoints.BoundingBoxFormat.XYXY) BOUNDING_BOX = make_bounding_boxes(DEFAULT_SIZE, format=tv_tensors.BoundingBoxFormat.XYXY)
MASK = make_detection_mask(DEFAULT_SIZE) MASK = make_detection_mask(DEFAULT_SIZE)
@pytest.mark.parametrize( @pytest.mark.parametrize(
("sample", "types", "expected"), ("sample", "types", "expected"),
[ [
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True), ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes,), True), ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True), ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), True), ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True), ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes, datapoints.Mask), True), ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
((MASK,), (datapoints.Image, datapoints.BoundingBoxes), False), ((MASK,), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
((BOUNDING_BOX,), (datapoints.Image, datapoints.Mask), False), ((BOUNDING_BOX,), (tv_tensors.Image, tv_tensors.Mask), False),
((IMAGE,), (datapoints.BoundingBoxes, datapoints.Mask), False), ((IMAGE,), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
( (
(IMAGE, BOUNDING_BOX, MASK), (IMAGE, BOUNDING_BOX, MASK),
(datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
True, True,
), ),
((), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False), ((), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True), ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, tv_tensors.Image),), True),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), True), ((IMAGE,), (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), True),
( (
(torch.Tensor(IMAGE),), (torch.Tensor(IMAGE),),
(datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
True, True,
), ),
( (
(to_pil_image(IMAGE),), (to_pil_image(IMAGE),),
(datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
True, True,
), ),
], ],
...@@ -57,31 +57,31 @@ def test_has_any(sample, types, expected): ...@@ -57,31 +57,31 @@ def test_has_any(sample, types, expected):
@pytest.mark.parametrize( @pytest.mark.parametrize(
("sample", "types", "expected"), ("sample", "types", "expected"),
[ [
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True), ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes,), True), ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True), ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), True), ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True), ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes, datapoints.Mask), True), ((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
( (
(IMAGE, BOUNDING_BOX, MASK), (IMAGE, BOUNDING_BOX, MASK),
(datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
True, True,
), ),
((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), False), ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), False), ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), False),
((IMAGE, MASK), (datapoints.BoundingBoxes, datapoints.Mask), False), ((IMAGE, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
( (
(IMAGE, BOUNDING_BOX, MASK), (IMAGE, BOUNDING_BOX, MASK),
(datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
True, True,
), ),
((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False), ((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
((IMAGE, MASK), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False), ((IMAGE, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
((IMAGE, BOUNDING_BOX), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False), ((IMAGE, BOUNDING_BOX), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
( (
(IMAGE, BOUNDING_BOX, MASK), (IMAGE, BOUNDING_BOX, MASK),
(lambda obj: isinstance(obj, (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)),), (lambda obj: isinstance(obj, (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask)),),
True, True,
), ),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
......
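Taken together, these parametrizations describe the query helpers the v2 transforms use to inspect a flattened sample: has_any succeeds if at least one item matches one of the given types (or passes a check callable), has_all only if every given type is represented. A hedged sketch, assuming the helpers take the sample followed by the types as varargs and reusing the common_utils factories imported above:

from torchvision import tv_tensors
from torchvision.transforms.v2._utils import has_all, has_any
from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_mask, make_image

image = make_image(DEFAULT_SIZE, color_space="RGB")
boxes = make_bounding_boxes(DEFAULT_SIZE, format=tv_tensors.BoundingBoxFormat.XYXY)
mask = make_detection_mask(DEFAULT_SIZE)

assert has_any((image, boxes, mask), tv_tensors.Image)                      # one match is enough
assert not has_any((mask,), tv_tensors.Image, tv_tensors.BoundingBoxes)     # nothing matches
assert has_all((image, boxes), tv_tensors.Image, tv_tensors.BoundingBoxes)  # every type is present
assert not has_all((image, boxes), tv_tensors.Image, tv_tensors.Mask)       # the mask is missing
assert has_any((image,), lambda obj: isinstance(obj, tv_tensors.Image))     # callables work too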
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
from common_utils import assert_equal, make_bounding_boxes, make_image, make_segmentation_mask, make_video from common_utils import assert_equal, make_bounding_boxes, make_image, make_segmentation_mask, make_video
from PIL import Image from PIL import Image
from torchvision import datapoints from torchvision import tv_tensors
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
...@@ -13,40 +13,40 @@ def restore_tensor_return_type(): ...@@ -13,40 +13,40 @@ def restore_tensor_return_type():
# This is for security, as we should already be restoring the default manually in each test anyway # This is for security, as we should already be restoring the default manually in each test anyway
# (at least at the time of writing...) # (at least at the time of writing...)
yield yield
datapoints.set_return_type("Tensor") tv_tensors.set_return_type("Tensor")
@pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)]) @pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)])
def test_image_instance(data): def test_image_instance(data):
image = datapoints.Image(data) image = tv_tensors.Image(data)
assert isinstance(image, torch.Tensor) assert isinstance(image, torch.Tensor)
assert image.ndim == 3 and image.shape[0] == 3 assert image.ndim == 3 and image.shape[0] == 3
@pytest.mark.parametrize("data", [torch.randint(0, 10, size=(1, 32, 32)), Image.new("L", (32, 32), color=2)]) @pytest.mark.parametrize("data", [torch.randint(0, 10, size=(1, 32, 32)), Image.new("L", (32, 32), color=2)])
def test_mask_instance(data): def test_mask_instance(data):
mask = datapoints.Mask(data) mask = tv_tensors.Mask(data)
assert isinstance(mask, torch.Tensor) assert isinstance(mask, torch.Tensor)
assert mask.ndim == 3 and mask.shape[0] == 1 assert mask.ndim == 3 and mask.shape[0] == 1
@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]], [1, 2, 3, 4]]) @pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]], [1, 2, 3, 4]])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH] "format", ["XYXY", "CXCYWH", tv_tensors.BoundingBoxFormat.XYXY, tv_tensors.BoundingBoxFormat.XYWH]
) )
def test_bbox_instance(data, format): def test_bbox_instance(data, format):
bboxes = datapoints.BoundingBoxes(data, format=format, canvas_size=(32, 32)) bboxes = tv_tensors.BoundingBoxes(data, format=format, canvas_size=(32, 32))
assert isinstance(bboxes, torch.Tensor) assert isinstance(bboxes, torch.Tensor)
assert bboxes.ndim == 2 and bboxes.shape[1] == 4 assert bboxes.ndim == 2 and bboxes.shape[1] == 4
if isinstance(format, str): if isinstance(format, str):
format = datapoints.BoundingBoxFormat[(format.upper())] format = tv_tensors.BoundingBoxFormat[(format.upper())]
assert bboxes.format == format assert bboxes.format == format
def test_bbox_dim_error(): def test_bbox_dim_error():
data_3d = [[[1, 2, 3, 4]]] data_3d = [[[1, 2, 3, 4]]]
with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"): with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"):
datapoints.BoundingBoxes(data_3d, format="XYXY", canvas_size=(32, 32)) tv_tensors.BoundingBoxes(data_3d, format="XYXY", canvas_size=(32, 32))
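The three construction tests above amount to the following minimal sketch (values are illustrative): each tv_tensor is a torch.Tensor subclass, and BoundingBoxes additionally normalizes a string format to the BoundingBoxFormat enum and records the canvas size.

import torch
from torchvision import tv_tensors

img = tv_tensors.Image(torch.rand(3, 32, 32))                        # CHW image tensor
mask = tv_tensors.Mask(torch.randint(0, 10, size=(1, 32, 32)))       # integer mask
boxes = tv_tensors.BoundingBoxes([[0, 0, 5, 5]], format="XYXY", canvas_size=(32, 32))

assert isinstance(img, torch.Tensor) and isinstance(mask, torch.Tensor)
assert boxes.format == tv_tensors.BoundingBoxFormat.XYXY             # "XYXY" was normalized
assert boxes.ndim == 2 and boxes.shape[1] == 4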
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -64,8 +64,8 @@ def test_bbox_dim_error(): ...@@ -64,8 +64,8 @@ def test_bbox_dim_error():
], ],
) )
def test_new_requires_grad(data, input_requires_grad, expected_requires_grad): def test_new_requires_grad(data, input_requires_grad, expected_requires_grad):
datapoint = datapoints.Image(data, requires_grad=input_requires_grad) tv_tensor = tv_tensors.Image(data, requires_grad=input_requires_grad)
assert datapoint.requires_grad is expected_requires_grad assert tv_tensor.requires_grad is expected_requires_grad
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
...@@ -75,7 +75,7 @@ def test_isinstance(make_input): ...@@ -75,7 +75,7 @@ def test_isinstance(make_input):
def test_wrapping_no_copy(): def test_wrapping_no_copy():
tensor = torch.rand(3, 16, 16) tensor = torch.rand(3, 16, 16)
image = datapoints.Image(tensor) image = tv_tensors.Image(tensor)
assert image.data_ptr() == tensor.data_ptr() assert image.data_ptr() == tensor.data_ptr()
...@@ -91,25 +91,25 @@ def test_to_wrapping(make_input): ...@@ -91,25 +91,25 @@ def test_to_wrapping(make_input):
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"]) @pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_to_datapoint_reference(make_input, return_type): def test_to_tv_tensor_reference(make_input, return_type):
tensor = torch.rand((3, 16, 16), dtype=torch.float64) tensor = torch.rand((3, 16, 16), dtype=torch.float64)
dp = make_input() dp = make_input()
with datapoints.set_return_type(return_type): with tv_tensors.set_return_type(return_type):
tensor_to = tensor.to(dp) tensor_to = tensor.to(dp)
assert type(tensor_to) is (type(dp) if return_type == "datapoint" else torch.Tensor) assert type(tensor_to) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
assert tensor_to.dtype is dp.dtype assert tensor_to.dtype is dp.dtype
assert type(tensor) is torch.Tensor assert type(tensor) is torch.Tensor
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"]) @pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_clone_wrapping(make_input, return_type): def test_clone_wrapping(make_input, return_type):
dp = make_input() dp = make_input()
with datapoints.set_return_type(return_type): with tv_tensors.set_return_type(return_type):
dp_clone = dp.clone() dp_clone = dp.clone()
assert type(dp_clone) is type(dp) assert type(dp_clone) is type(dp)
...@@ -117,13 +117,13 @@ def test_clone_wrapping(make_input, return_type): ...@@ -117,13 +117,13 @@ def test_clone_wrapping(make_input, return_type):
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"]) @pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_requires_grad__wrapping(make_input, return_type): def test_requires_grad__wrapping(make_input, return_type):
dp = make_input(dtype=torch.float) dp = make_input(dtype=torch.float)
assert not dp.requires_grad assert not dp.requires_grad
with datapoints.set_return_type(return_type): with tv_tensors.set_return_type(return_type):
dp_requires_grad = dp.requires_grad_(True) dp_requires_grad = dp.requires_grad_(True)
assert type(dp_requires_grad) is type(dp) assert type(dp_requires_grad) is type(dp)
...@@ -132,54 +132,54 @@ def test_requires_grad__wrapping(make_input, return_type): ...@@ -132,54 +132,54 @@ def test_requires_grad__wrapping(make_input, return_type):
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"]) @pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_detach_wrapping(make_input, return_type): def test_detach_wrapping(make_input, return_type):
dp = make_input(dtype=torch.float).requires_grad_(True) dp = make_input(dtype=torch.float).requires_grad_(True)
with datapoints.set_return_type(return_type): with tv_tensors.set_return_type(return_type):
dp_detached = dp.detach() dp_detached = dp.detach()
assert type(dp_detached) is type(dp) assert type(dp_detached) is type(dp)
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"]) @pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_force_subclass_with_metadata(return_type): def test_force_subclass_with_metadata(return_type):
# Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS and datapoints with metadata # Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS and tv_tensors with metadata
# Largely the same as above, we additionally check that the metadata is preserved # Largely the same as above, we additionally check that the metadata is preserved
format, canvas_size = "XYXY", (32, 32) format, canvas_size = "XYXY", (32, 32)
bbox = datapoints.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size) bbox = tv_tensors.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size)
datapoints.set_return_type(return_type) tv_tensors.set_return_type(return_type)
bbox = bbox.clone() bbox = bbox.clone()
if return_type == "datapoint": if return_type == "tv_tensor":
assert bbox.format, bbox.canvas_size == (format, canvas_size) assert bbox.format, bbox.canvas_size == (format, canvas_size)
bbox = bbox.to(torch.float64) bbox = bbox.to(torch.float64)
if return_type == "datapoint": if return_type == "tv_tensor":
assert bbox.format, bbox.canvas_size == (format, canvas_size) assert bbox.format, bbox.canvas_size == (format, canvas_size)
bbox = bbox.detach() bbox = bbox.detach()
if return_type == "datapoint": if return_type == "tv_tensor":
assert bbox.format, bbox.canvas_size == (format, canvas_size) assert bbox.format, bbox.canvas_size == (format, canvas_size)
assert not bbox.requires_grad assert not bbox.requires_grad
bbox.requires_grad_(True) bbox.requires_grad_(True)
if return_type == "datapoint": if return_type == "tv_tensor":
assert bbox.format, bbox.canvas_size == (format, canvas_size) assert bbox.format, bbox.canvas_size == (format, canvas_size)
assert bbox.requires_grad assert bbox.requires_grad
datapoints.set_return_type("tensor") tv_tensors.set_return_type("tensor")
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"]) @pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_other_op_no_wrapping(make_input, return_type): def test_other_op_no_wrapping(make_input, return_type):
dp = make_input() dp = make_input()
with datapoints.set_return_type(return_type): with tv_tensors.set_return_type(return_type):
# any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
output = dp * 2 output = dp * 2
assert type(output) is (type(dp) if return_type == "datapoint" else torch.Tensor) assert type(output) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
...@@ -200,15 +200,15 @@ def test_no_tensor_output_op_no_wrapping(make_input, op): ...@@ -200,15 +200,15 @@ def test_no_tensor_output_op_no_wrapping(make_input, op):
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"]) @pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_inplace_op_no_wrapping(make_input, return_type): def test_inplace_op_no_wrapping(make_input, return_type):
dp = make_input() dp = make_input()
original_type = type(dp) original_type = type(dp)
with datapoints.set_return_type(return_type): with tv_tensors.set_return_type(return_type):
output = dp.add_(0) output = dp.add_(0)
assert type(output) is (type(dp) if return_type == "datapoint" else torch.Tensor) assert type(output) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
assert type(dp) is original_type assert type(dp) is original_type
...@@ -219,7 +219,7 @@ def test_wrap(make_input): ...@@ -219,7 +219,7 @@ def test_wrap(make_input):
# any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
output = dp * 2 output = dp * 2
dp_new = datapoints.wrap(output, like=dp) dp_new = tv_tensors.wrap(output, like=dp)
assert type(dp_new) is type(dp) assert type(dp_new) is type(dp)
assert dp_new.data_ptr() == output.data_ptr() assert dp_new.data_ptr() == output.data_ptr()
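test_wrap above captures the idiom for getting a tv_tensor back after an operation that intentionally returns a plain tensor: tv_tensors.wrap(output, like=dp) re-applies the subclass (and, for types like BoundingBoxes, the metadata of `like`) without copying data. A short sketch under that assumption:

import torch
from torchvision import tv_tensors

boxes = tv_tensors.BoundingBoxes([[0, 0, 5, 5]], format="XYXY", canvas_size=(32, 32))
shifted = boxes + 1                                   # plain torch.Tensor under the default return type
rewrapped = tv_tensors.wrap(shifted, like=boxes)

assert type(rewrapped) is tv_tensors.BoundingBoxes
assert rewrapped.data_ptr() == shifted.data_ptr()     # same storage, no copy
assert rewrapped.format == boxes.format               # metadata comes from `like`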
...@@ -243,7 +243,7 @@ def test_deepcopy(make_input, requires_grad): ...@@ -243,7 +243,7 @@ def test_deepcopy(make_input, requires_grad):
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video]) @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"]) @pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"op", "op",
( (
...@@ -265,10 +265,10 @@ def test_deepcopy(make_input, requires_grad): ...@@ -265,10 +265,10 @@ def test_deepcopy(make_input, requires_grad):
def test_usual_operations(make_input, return_type, op): def test_usual_operations(make_input, return_type, op):
dp = make_input() dp = make_input()
with datapoints.set_return_type(return_type): with tv_tensors.set_return_type(return_type):
out = op(dp) out = op(dp)
assert type(out) is (type(dp) if return_type == "datapoint" else torch.Tensor) assert type(out) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
if isinstance(dp, datapoints.BoundingBoxes) and return_type == "datapoint": if isinstance(dp, tv_tensors.BoundingBoxes) and return_type == "tv_tensor":
assert hasattr(out, "format") assert hasattr(out, "format")
assert hasattr(out, "canvas_size") assert hasattr(out, "canvas_size")
...@@ -286,22 +286,22 @@ def test_set_return_type(): ...@@ -286,22 +286,22 @@ def test_set_return_type():
assert type(img + 3) is torch.Tensor assert type(img + 3) is torch.Tensor
with datapoints.set_return_type("datapoint"): with tv_tensors.set_return_type("tv_tensor"):
assert type(img + 3) is datapoints.Image assert type(img + 3) is tv_tensors.Image
assert type(img + 3) is torch.Tensor assert type(img + 3) is torch.Tensor
datapoints.set_return_type("datapoint") tv_tensors.set_return_type("tv_tensor")
assert type(img + 3) is datapoints.Image assert type(img + 3) is tv_tensors.Image
with datapoints.set_return_type("tensor"): with tv_tensors.set_return_type("tensor"):
assert type(img + 3) is torch.Tensor assert type(img + 3) is torch.Tensor
with datapoints.set_return_type("datapoint"): with tv_tensors.set_return_type("tv_tensor"):
assert type(img + 3) is datapoints.Image assert type(img + 3) is tv_tensors.Image
datapoints.set_return_type("tensor") tv_tensors.set_return_type("tensor")
assert type(img + 3) is torch.Tensor assert type(img + 3) is torch.Tensor
assert type(img + 3) is torch.Tensor assert type(img + 3) is torch.Tensor
# Exiting a context manager will restore the return type as it was prior to entering it, # Exiting a context manager will restore the return type as it was prior to entering it,
# regardless of whether the "global" datapoints.set_return_type() was called within the context manager. # regardless of whether the "global" tv_tensors.set_return_type() was called within the context manager.
assert type(img + 3) is datapoints.Image assert type(img + 3) is tv_tensors.Image
datapoints.set_return_type("tensor") tv_tensors.set_return_type("tensor")
...@@ -2,7 +2,7 @@ import collections.abc ...@@ -2,7 +2,7 @@ import collections.abc
import pytest import pytest
import torchvision.transforms.v2.functional as F import torchvision.transforms.v2.functional as F
from torchvision import datapoints from torchvision import tv_tensors
from transforms_v2_kernel_infos import KERNEL_INFOS, pad_xfail_jit_fill_condition from transforms_v2_kernel_infos import KERNEL_INFOS, pad_xfail_jit_fill_condition
from transforms_v2_legacy_utils import InfoBase, TestMark from transforms_v2_legacy_utils import InfoBase, TestMark
...@@ -44,19 +44,19 @@ class DispatcherInfo(InfoBase): ...@@ -44,19 +44,19 @@ class DispatcherInfo(InfoBase):
self.pil_kernel_info = pil_kernel_info self.pil_kernel_info = pil_kernel_info
kernel_infos = {} kernel_infos = {}
for datapoint_type, kernel in self.kernels.items(): for tv_tensor_type, kernel in self.kernels.items():
kernel_info = self._KERNEL_INFO_MAP.get(kernel) kernel_info = self._KERNEL_INFO_MAP.get(kernel)
if not kernel_info: if not kernel_info:
raise pytest.UsageError( raise pytest.UsageError(
f"Can't register {kernel.__name__} for type {datapoint_type} since there is no `KernelInfo` for it. " f"Can't register {kernel.__name__} for type {tv_tensor_type} since there is no `KernelInfo` for it. "
f"Please add a `KernelInfo` for it in `transforms_v2_kernel_infos.py`." f"Please add a `KernelInfo` for it in `transforms_v2_kernel_infos.py`."
) )
kernel_infos[datapoint_type] = kernel_info kernel_infos[tv_tensor_type] = kernel_info
self.kernel_infos = kernel_infos self.kernel_infos = kernel_infos
def sample_inputs(self, *datapoint_types, filter_metadata=True): def sample_inputs(self, *tv_tensor_types, filter_metadata=True):
for datapoint_type in datapoint_types or self.kernel_infos.keys(): for tv_tensor_type in tv_tensor_types or self.kernel_infos.keys():
kernel_info = self.kernel_infos.get(datapoint_type) kernel_info = self.kernel_infos.get(tv_tensor_type)
if not kernel_info: if not kernel_info:
raise pytest.UsageError(f"There is no kernel registered for type {type.__name__}") raise pytest.UsageError(f"There is no kernel registered for type {type.__name__}")
...@@ -69,12 +69,12 @@ class DispatcherInfo(InfoBase): ...@@ -69,12 +69,12 @@ class DispatcherInfo(InfoBase):
import itertools import itertools
for args_kwargs in sample_inputs: for args_kwargs in sample_inputs:
if hasattr(datapoint_type, "__annotations__"): if hasattr(tv_tensor_type, "__annotations__"):
for name in itertools.chain( for name in itertools.chain(
datapoint_type.__annotations__.keys(), tv_tensor_type.__annotations__.keys(),
# FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a # FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a
# per-dispatcher level. However, so far there is no option for that. # per-dispatcher level. However, so far there is no option for that.
(f"old_{name}" for name in datapoint_type.__annotations__.keys()), (f"old_{name}" for name in tv_tensor_type.__annotations__.keys()),
): ):
if name in args_kwargs.kwargs: if name in args_kwargs.kwargs:
del args_kwargs.kwargs[name] del args_kwargs.kwargs[name]
...@@ -97,9 +97,9 @@ def xfail_jit_python_scalar_arg(name, *, reason=None): ...@@ -97,9 +97,9 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
) )
skip_dispatch_datapoint = TestMark( skip_dispatch_tv_tensor = TestMark(
("TestDispatchers", "test_dispatch_datapoint"), ("TestDispatchers", "test_dispatch_tv_tensor"),
pytest.mark.skip(reason="Dispatcher doesn't support arbitrary datapoint dispatch."), pytest.mark.skip(reason="Dispatcher doesn't support arbitrary tv_tensor dispatch."),
) )
multi_crop_skips = [ multi_crop_skips = [
...@@ -107,9 +107,9 @@ multi_crop_skips = [ ...@@ -107,9 +107,9 @@ multi_crop_skips = [
("TestDispatchers", test_name), ("TestDispatchers", test_name),
pytest.mark.skip(reason="Multi-crop dispatchers return a sequence of items rather than a single one."), pytest.mark.skip(reason="Multi-crop dispatchers return a sequence of items rather than a single one."),
) )
for test_name in ["test_pure_tensor_output_type", "test_pil_output_type", "test_datapoint_output_type"] for test_name in ["test_pure_tensor_output_type", "test_pil_output_type", "test_tv_tensor_output_type"]
] ]
multi_crop_skips.append(skip_dispatch_datapoint) multi_crop_skips.append(skip_dispatch_tv_tensor)
def xfails_pil(reason, *, condition=None): def xfails_pil(reason, *, condition=None):
...@@ -142,30 +142,30 @@ DISPATCHER_INFOS = [ ...@@ -142,30 +142,30 @@ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.crop, F.crop,
kernels={ kernels={
datapoints.Image: F.crop_image, tv_tensors.Image: F.crop_image,
datapoints.Video: F.crop_video, tv_tensors.Video: F.crop_video,
datapoints.BoundingBoxes: F.crop_bounding_boxes, tv_tensors.BoundingBoxes: F.crop_bounding_boxes,
datapoints.Mask: F.crop_mask, tv_tensors.Mask: F.crop_mask,
}, },
pil_kernel_info=PILKernelInfo(F._crop_image_pil, kernel_name="crop_image_pil"), pil_kernel_info=PILKernelInfo(F._crop_image_pil, kernel_name="crop_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.resized_crop, F.resized_crop,
kernels={ kernels={
datapoints.Image: F.resized_crop_image, tv_tensors.Image: F.resized_crop_image,
datapoints.Video: F.resized_crop_video, tv_tensors.Video: F.resized_crop_video,
datapoints.BoundingBoxes: F.resized_crop_bounding_boxes, tv_tensors.BoundingBoxes: F.resized_crop_bounding_boxes,
datapoints.Mask: F.resized_crop_mask, tv_tensors.Mask: F.resized_crop_mask,
}, },
pil_kernel_info=PILKernelInfo(F._resized_crop_image_pil), pil_kernel_info=PILKernelInfo(F._resized_crop_image_pil),
), ),
DispatcherInfo( DispatcherInfo(
F.pad, F.pad,
kernels={ kernels={
datapoints.Image: F.pad_image, tv_tensors.Image: F.pad_image,
datapoints.Video: F.pad_video, tv_tensors.Video: F.pad_video,
datapoints.BoundingBoxes: F.pad_bounding_boxes, tv_tensors.BoundingBoxes: F.pad_bounding_boxes,
datapoints.Mask: F.pad_mask, tv_tensors.Mask: F.pad_mask,
}, },
pil_kernel_info=PILKernelInfo(F._pad_image_pil, kernel_name="pad_image_pil"), pil_kernel_info=PILKernelInfo(F._pad_image_pil, kernel_name="pad_image_pil"),
test_marks=[ test_marks=[
...@@ -184,10 +184,10 @@ DISPATCHER_INFOS = [ ...@@ -184,10 +184,10 @@ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.perspective, F.perspective,
kernels={ kernels={
datapoints.Image: F.perspective_image, tv_tensors.Image: F.perspective_image,
datapoints.Video: F.perspective_video, tv_tensors.Video: F.perspective_video,
datapoints.BoundingBoxes: F.perspective_bounding_boxes, tv_tensors.BoundingBoxes: F.perspective_bounding_boxes,
datapoints.Mask: F.perspective_mask, tv_tensors.Mask: F.perspective_mask,
}, },
pil_kernel_info=PILKernelInfo(F._perspective_image_pil), pil_kernel_info=PILKernelInfo(F._perspective_image_pil),
test_marks=[ test_marks=[
...@@ -198,10 +198,10 @@ DISPATCHER_INFOS = [ ...@@ -198,10 +198,10 @@ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.elastic, F.elastic,
kernels={ kernels={
datapoints.Image: F.elastic_image, tv_tensors.Image: F.elastic_image,
datapoints.Video: F.elastic_video, tv_tensors.Video: F.elastic_video,
datapoints.BoundingBoxes: F.elastic_bounding_boxes, tv_tensors.BoundingBoxes: F.elastic_bounding_boxes,
datapoints.Mask: F.elastic_mask, tv_tensors.Mask: F.elastic_mask,
}, },
pil_kernel_info=PILKernelInfo(F._elastic_image_pil), pil_kernel_info=PILKernelInfo(F._elastic_image_pil),
test_marks=[xfail_jit_python_scalar_arg("fill")], test_marks=[xfail_jit_python_scalar_arg("fill")],
...@@ -209,10 +209,10 @@ DISPATCHER_INFOS = [ ...@@ -209,10 +209,10 @@ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.center_crop, F.center_crop,
kernels={ kernels={
datapoints.Image: F.center_crop_image, tv_tensors.Image: F.center_crop_image,
datapoints.Video: F.center_crop_video, tv_tensors.Video: F.center_crop_video,
datapoints.BoundingBoxes: F.center_crop_bounding_boxes, tv_tensors.BoundingBoxes: F.center_crop_bounding_boxes,
datapoints.Mask: F.center_crop_mask, tv_tensors.Mask: F.center_crop_mask,
}, },
pil_kernel_info=PILKernelInfo(F._center_crop_image_pil), pil_kernel_info=PILKernelInfo(F._center_crop_image_pil),
test_marks=[ test_marks=[
...@@ -222,8 +222,8 @@ DISPATCHER_INFOS = [ ...@@ -222,8 +222,8 @@ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.gaussian_blur, F.gaussian_blur,
kernels={ kernels={
datapoints.Image: F.gaussian_blur_image, tv_tensors.Image: F.gaussian_blur_image,
datapoints.Video: F.gaussian_blur_video, tv_tensors.Video: F.gaussian_blur_video,
}, },
pil_kernel_info=PILKernelInfo(F._gaussian_blur_image_pil), pil_kernel_info=PILKernelInfo(F._gaussian_blur_image_pil),
test_marks=[ test_marks=[
...@@ -234,99 +234,99 @@ DISPATCHER_INFOS = [ ...@@ -234,99 +234,99 @@ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.equalize, F.equalize,
kernels={ kernels={
datapoints.Image: F.equalize_image, tv_tensors.Image: F.equalize_image,
datapoints.Video: F.equalize_video, tv_tensors.Video: F.equalize_video,
}, },
pil_kernel_info=PILKernelInfo(F._equalize_image_pil, kernel_name="equalize_image_pil"), pil_kernel_info=PILKernelInfo(F._equalize_image_pil, kernel_name="equalize_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.invert, F.invert,
kernels={ kernels={
datapoints.Image: F.invert_image, tv_tensors.Image: F.invert_image,
datapoints.Video: F.invert_video, tv_tensors.Video: F.invert_video,
}, },
pil_kernel_info=PILKernelInfo(F._invert_image_pil, kernel_name="invert_image_pil"), pil_kernel_info=PILKernelInfo(F._invert_image_pil, kernel_name="invert_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.posterize, F.posterize,
kernels={ kernels={
datapoints.Image: F.posterize_image, tv_tensors.Image: F.posterize_image,
datapoints.Video: F.posterize_video, tv_tensors.Video: F.posterize_video,
}, },
pil_kernel_info=PILKernelInfo(F._posterize_image_pil, kernel_name="posterize_image_pil"), pil_kernel_info=PILKernelInfo(F._posterize_image_pil, kernel_name="posterize_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.solarize, F.solarize,
kernels={ kernels={
datapoints.Image: F.solarize_image, tv_tensors.Image: F.solarize_image,
datapoints.Video: F.solarize_video, tv_tensors.Video: F.solarize_video,
}, },
pil_kernel_info=PILKernelInfo(F._solarize_image_pil, kernel_name="solarize_image_pil"), pil_kernel_info=PILKernelInfo(F._solarize_image_pil, kernel_name="solarize_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.autocontrast, F.autocontrast,
kernels={ kernels={
datapoints.Image: F.autocontrast_image, tv_tensors.Image: F.autocontrast_image,
datapoints.Video: F.autocontrast_video, tv_tensors.Video: F.autocontrast_video,
}, },
pil_kernel_info=PILKernelInfo(F._autocontrast_image_pil, kernel_name="autocontrast_image_pil"), pil_kernel_info=PILKernelInfo(F._autocontrast_image_pil, kernel_name="autocontrast_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.adjust_sharpness, F.adjust_sharpness,
kernels={ kernels={
datapoints.Image: F.adjust_sharpness_image, tv_tensors.Image: F.adjust_sharpness_image,
datapoints.Video: F.adjust_sharpness_video, tv_tensors.Video: F.adjust_sharpness_video,
}, },
pil_kernel_info=PILKernelInfo(F._adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"), pil_kernel_info=PILKernelInfo(F._adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.erase, F.erase,
kernels={ kernels={
datapoints.Image: F.erase_image, tv_tensors.Image: F.erase_image,
datapoints.Video: F.erase_video, tv_tensors.Video: F.erase_video,
}, },
pil_kernel_info=PILKernelInfo(F._erase_image_pil), pil_kernel_info=PILKernelInfo(F._erase_image_pil),
test_marks=[ test_marks=[
skip_dispatch_datapoint, skip_dispatch_tv_tensor,
], ],
), ),
DispatcherInfo( DispatcherInfo(
F.adjust_contrast, F.adjust_contrast,
kernels={ kernels={
datapoints.Image: F.adjust_contrast_image, tv_tensors.Image: F.adjust_contrast_image,
datapoints.Video: F.adjust_contrast_video, tv_tensors.Video: F.adjust_contrast_video,
}, },
pil_kernel_info=PILKernelInfo(F._adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"), pil_kernel_info=PILKernelInfo(F._adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.adjust_gamma, F.adjust_gamma,
kernels={ kernels={
datapoints.Image: F.adjust_gamma_image, tv_tensors.Image: F.adjust_gamma_image,
datapoints.Video: F.adjust_gamma_video, tv_tensors.Video: F.adjust_gamma_video,
}, },
pil_kernel_info=PILKernelInfo(F._adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"), pil_kernel_info=PILKernelInfo(F._adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.adjust_hue, F.adjust_hue,
kernels={ kernels={
datapoints.Image: F.adjust_hue_image, tv_tensors.Image: F.adjust_hue_image,
datapoints.Video: F.adjust_hue_video, tv_tensors.Video: F.adjust_hue_video,
}, },
pil_kernel_info=PILKernelInfo(F._adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"), pil_kernel_info=PILKernelInfo(F._adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.adjust_saturation, F.adjust_saturation,
kernels={ kernels={
datapoints.Image: F.adjust_saturation_image, tv_tensors.Image: F.adjust_saturation_image,
datapoints.Video: F.adjust_saturation_video, tv_tensors.Video: F.adjust_saturation_video,
}, },
pil_kernel_info=PILKernelInfo(F._adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"), pil_kernel_info=PILKernelInfo(F._adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"),
), ),
DispatcherInfo( DispatcherInfo(
F.five_crop, F.five_crop,
kernels={ kernels={
datapoints.Image: F.five_crop_image, tv_tensors.Image: F.five_crop_image,
datapoints.Video: F.five_crop_video, tv_tensors.Video: F.five_crop_video,
}, },
pil_kernel_info=PILKernelInfo(F._five_crop_image_pil), pil_kernel_info=PILKernelInfo(F._five_crop_image_pil),
test_marks=[ test_marks=[
...@@ -337,8 +337,8 @@ DISPATCHER_INFOS = [ ...@@ -337,8 +337,8 @@ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.ten_crop, F.ten_crop,
kernels={ kernels={
datapoints.Image: F.ten_crop_image, tv_tensors.Image: F.ten_crop_image,
datapoints.Video: F.ten_crop_video, tv_tensors.Video: F.ten_crop_video,
}, },
test_marks=[ test_marks=[
xfail_jit_python_scalar_arg("size"), xfail_jit_python_scalar_arg("size"),
...@@ -349,8 +349,8 @@ DISPATCHER_INFOS = [ ...@@ -349,8 +349,8 @@ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.normalize, F.normalize,
kernels={ kernels={
datapoints.Image: F.normalize_image, tv_tensors.Image: F.normalize_image,
datapoints.Video: F.normalize_video, tv_tensors.Video: F.normalize_video,
}, },
test_marks=[ test_marks=[
xfail_jit_python_scalar_arg("mean"), xfail_jit_python_scalar_arg("mean"),
...@@ -360,24 +360,24 @@ DISPATCHER_INFOS = [ ...@@ -360,24 +360,24 @@ DISPATCHER_INFOS = [
DispatcherInfo( DispatcherInfo(
F.uniform_temporal_subsample, F.uniform_temporal_subsample,
kernels={ kernels={
datapoints.Video: F.uniform_temporal_subsample_video, tv_tensors.Video: F.uniform_temporal_subsample_video,
}, },
test_marks=[ test_marks=[
skip_dispatch_datapoint, skip_dispatch_tv_tensor,
], ],
), ),
DispatcherInfo( DispatcherInfo(
F.clamp_bounding_boxes, F.clamp_bounding_boxes,
kernels={datapoints.BoundingBoxes: F.clamp_bounding_boxes}, kernels={tv_tensors.BoundingBoxes: F.clamp_bounding_boxes},
test_marks=[ test_marks=[
skip_dispatch_datapoint, skip_dispatch_tv_tensor,
], ],
), ),
DispatcherInfo( DispatcherInfo(
F.convert_bounding_box_format, F.convert_bounding_box_format,
kernels={datapoints.BoundingBoxes: F.convert_bounding_box_format}, kernels={tv_tensors.BoundingBoxes: F.convert_bounding_box_format},
test_marks=[ test_marks=[
skip_dispatch_datapoint, skip_dispatch_tv_tensor,
], ],
), ),
] ]
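DISPATCHER_INFOS above is essentially a table mapping each public dispatcher to its per-type kernels; at call time the dispatcher routes an input to the kernel registered for its tv_tensor type and wraps the result back. A hedged illustration with F.crop (parameter names follow the v2 functional API; the sizes are arbitrary):

import torch
from torchvision import tv_tensors
import torchvision.transforms.v2.functional as F

img = tv_tensors.Image(torch.rand(3, 32, 32))
boxes = tv_tensors.BoundingBoxes([[0, 0, 10, 10]], format="XYXY", canvas_size=(32, 32))

cropped_img = F.crop(img, top=2, left=2, height=16, width=16)       # routed to F.crop_image
cropped_boxes = F.crop(boxes, top=2, left=2, height=16, width=16)   # routed to F.crop_bounding_boxes

assert isinstance(cropped_img, tv_tensors.Image)
assert isinstance(cropped_boxes, tv_tensors.BoundingBoxes)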
...@@ -7,7 +7,7 @@ import pytest ...@@ -7,7 +7,7 @@ import pytest
import torch.testing import torch.testing
import torchvision.ops import torchvision.ops
import torchvision.transforms.v2.functional as F import torchvision.transforms.v2.functional as F
from torchvision import datapoints from torchvision import tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value, _parse_pad_padding from torchvision.transforms._functional_tensor import _max_value as get_max_value, _parse_pad_padding
from transforms_v2_legacy_utils import ( from transforms_v2_legacy_utils import (
ArgsKwargs, ArgsKwargs,
...@@ -193,7 +193,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz ...@@ -193,7 +193,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
bbox_xyxy = F.convert_bounding_box_format( bbox_xyxy = F.convert_bounding_box_format(
bbox.as_subclass(torch.Tensor), bbox.as_subclass(torch.Tensor),
old_format=format_, old_format=format_,
new_format=datapoints.BoundingBoxFormat.XYXY, new_format=tv_tensors.BoundingBoxFormat.XYXY,
inplace=True, inplace=True,
) )
points = np.array( points = np.array(
...@@ -215,7 +215,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz ...@@ -215,7 +215,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
dtype=bbox_xyxy.dtype, dtype=bbox_xyxy.dtype,
) )
out_bbox = F.convert_bounding_box_format( out_bbox = F.convert_bounding_box_format(
out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True out_bbox, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
) )
# It is important to clamp before casting, especially for CXCYWH format, dtype=int64 # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
out_bbox = F.clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_) out_bbox = F.clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_)
...@@ -228,7 +228,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz ...@@ -228,7 +228,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
def sample_inputs_convert_bounding_box_format(): def sample_inputs_convert_bounding_box_format():
formats = list(datapoints.BoundingBoxFormat) formats = list(tv_tensors.BoundingBoxFormat)
for bounding_boxes_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats): for bounding_boxes_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
yield ArgsKwargs(bounding_boxes_loader, old_format=bounding_boxes_loader.format, new_format=new_format) yield ArgsKwargs(bounding_boxes_loader, old_format=bounding_boxes_loader.format, new_format=new_format)
...@@ -659,7 +659,7 @@ def sample_inputs_perspective_bounding_boxes(): ...@@ -659,7 +659,7 @@ def sample_inputs_perspective_bounding_boxes():
coefficients=_PERSPECTIVE_COEFFS[0], coefficients=_PERSPECTIVE_COEFFS[0],
) )
format = datapoints.BoundingBoxFormat.XYXY format = tv_tensors.BoundingBoxFormat.XYXY
loader = make_bounding_box_loader(format=format) loader = make_bounding_box_loader(format=format)
yield ArgsKwargs( yield ArgsKwargs(
loader, format=format, canvas_size=loader.canvas_size, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS loader, format=format, canvas_size=loader.canvas_size, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS
......
...@@ -27,7 +27,7 @@ import PIL.Image ...@@ -27,7 +27,7 @@ import PIL.Image
import pytest import pytest
import torch import torch
from torchvision import datapoints from torchvision import tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import to_dtype_image, to_image, to_pil_image from torchvision.transforms.v2.functional import to_dtype_image, to_image, to_pil_image
...@@ -82,7 +82,7 @@ def make_image( ...@@ -82,7 +82,7 @@ def make_image(
if color_space in {"GRAY_ALPHA", "RGBA"}: if color_space in {"GRAY_ALPHA", "RGBA"}:
data[..., -1, :, :] = max_value data[..., -1, :, :] = max_value
return datapoints.Image(data) return tv_tensors.Image(data)
def make_image_tensor(*args, **kwargs): def make_image_tensor(*args, **kwargs):
...@@ -96,7 +96,7 @@ def make_image_pil(*args, **kwargs): ...@@ -96,7 +96,7 @@ def make_image_pil(*args, **kwargs):
def make_bounding_boxes( def make_bounding_boxes(
canvas_size=DEFAULT_SIZE, canvas_size=DEFAULT_SIZE,
*, *,
format=datapoints.BoundingBoxFormat.XYXY, format=tv_tensors.BoundingBoxFormat.XYXY,
batch_dims=(), batch_dims=(),
dtype=None, dtype=None,
device="cpu", device="cpu",
...@@ -107,12 +107,12 @@ def make_bounding_boxes( ...@@ -107,12 +107,12 @@ def make_bounding_boxes(
return torch.stack([torch.randint(max_value - v, ()) for v in values.flatten().tolist()]).reshape(values.shape) return torch.stack([torch.randint(max_value - v, ()) for v in values.flatten().tolist()]).reshape(values.shape)
if isinstance(format, str): if isinstance(format, str):
format = datapoints.BoundingBoxFormat[format] format = tv_tensors.BoundingBoxFormat[format]
dtype = dtype or torch.float32 dtype = dtype or torch.float32
if any(dim == 0 for dim in batch_dims): if any(dim == 0 for dim in batch_dims):
return datapoints.BoundingBoxes( return tv_tensors.BoundingBoxes(
torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, canvas_size=canvas_size torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, canvas_size=canvas_size
) )
...@@ -120,28 +120,28 @@ def make_bounding_boxes( ...@@ -120,28 +120,28 @@ def make_bounding_boxes(
y = sample_position(h, canvas_size[0]) y = sample_position(h, canvas_size[0])
x = sample_position(w, canvas_size[1]) x = sample_position(w, canvas_size[1])
if format is datapoints.BoundingBoxFormat.XYWH: if format is tv_tensors.BoundingBoxFormat.XYWH:
parts = (x, y, w, h) parts = (x, y, w, h)
elif format is datapoints.BoundingBoxFormat.XYXY: elif format is tv_tensors.BoundingBoxFormat.XYXY:
x1, y1 = x, y x1, y1 = x, y
x2 = x1 + w x2 = x1 + w
y2 = y1 + h y2 = y1 + h
parts = (x1, y1, x2, y2) parts = (x1, y1, x2, y2)
elif format is datapoints.BoundingBoxFormat.CXCYWH: elif format is tv_tensors.BoundingBoxFormat.CXCYWH:
cx = x + w / 2 cx = x + w / 2
cy = y + h / 2 cy = y + h / 2
parts = (cx, cy, w, h) parts = (cx, cy, w, h)
else: else:
raise ValueError(f"Format {format} is not supported") raise ValueError(f"Format {format} is not supported")
return datapoints.BoundingBoxes( return tv_tensors.BoundingBoxes(
torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size
) )
def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtype=None, device="cpu"): def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtype=None, device="cpu"):
"""Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks""" """Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks"""
return datapoints.Mask( return tv_tensors.Mask(
torch.testing.make_tensor( torch.testing.make_tensor(
(*batch_dims, num_objects, *size), (*batch_dims, num_objects, *size),
low=0, low=0,
@@ -154,7 +154,7 @@ def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtype=None, device="cpu"):
 def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=(), dtype=None, device="cpu"):
     """Make a "segmentation" mask, i.e. (*, H, W), where the category is encoded as pixel value"""
-    return datapoints.Mask(
+    return tv_tensors.Mask(
         torch.testing.make_tensor(
             (*batch_dims, *size),
             low=0,

@@ -166,7 +166,7 @@ def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=(), dtype=None, device="cpu"):
 def make_video(size=DEFAULT_SIZE, *, num_frames=3, batch_dims=(), **kwargs):
-    return datapoints.Video(make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs))
+    return tv_tensors.Video(make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs))
 def make_video_tensor(*args, **kwargs):

@@ -335,7 +335,7 @@ def make_image_loader_for_interpolation(
         image_tensor = image_tensor.to(device=device)
         image_tensor = to_dtype_image(image_tensor, dtype=dtype, scale=True)
-        return datapoints.Image(image_tensor)
+        return tv_tensors.Image(image_tensor)
     return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, memory_format=memory_format)
@@ -352,7 +352,7 @@ def make_image_loaders_for_interpolation(
 @dataclasses.dataclass
 class BoundingBoxesLoader(TensorLoader):
-    format: datapoints.BoundingBoxFormat
+    format: tv_tensors.BoundingBoxFormat
     spatial_size: Tuple[int, int]
     canvas_size: Tuple[int, int] = dataclasses.field(init=False)

@@ -362,7 +362,7 @@ class BoundingBoxesLoader(TensorLoader):
 def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
     if isinstance(format, str):
-        format = datapoints.BoundingBoxFormat[format]
+        format = tv_tensors.BoundingBoxFormat[format]
     spatial_size = _parse_size(spatial_size, name="spatial_size")

@@ -381,7 +381,7 @@ def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
 def make_bounding_box_loaders(
     *,
     extra_dims=tuple(d for d in DEFAULT_EXTRA_DIMS if len(d) < 2),
-    formats=tuple(datapoints.BoundingBoxFormat),
+    formats=tuple(tv_tensors.BoundingBoxFormat),
     spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
     dtypes=(torch.float32, torch.float64, torch.int64),
 ):
...
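For orientation, a minimal sketch of the renamed API as exercised by the test helpers above; this is not part of the diff and assumes a torchvision build where torchvision.tv_tensors is available:

import torch
from torchvision import tv_tensors

# One XYXY box on a 32x32 canvas, mirroring what make_bounding_boxes() returns.
boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[0.0, 0.0, 10.0, 10.0]]),
    format=tv_tensors.BoundingBoxFormat.XYXY,
    canvas_size=(32, 32),
)

# A boolean "detection"-style mask with a single object, mirroring make_detection_mask().
mask = tv_tensors.Mask(torch.zeros(1, 32, 32, dtype=torch.bool))

print(boxes.format, boxes.canvas_size, mask.shape)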
@@ -137,7 +137,7 @@ __all__ = (
 # Ref: https://peps.python.org/pep-0562/
 def __getattr__(name):
     if name in ("wrap_dataset_for_transforms_v2",):
-        from torchvision.datapoints._dataset_wrapper import wrap_dataset_for_transforms_v2
+        from torchvision.tv_tensors._dataset_wrapper import wrap_dataset_for_transforms_v2
         return wrap_dataset_for_transforms_v2
...
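The lazily re-exported wrap_dataset_for_transforms_v2 above is the public entry point for getting tv_tensors out of the classic datasets; a hedged usage sketch follows, where the dataset paths are placeholders and not taken from this commit:

from torchvision import datasets

# Placeholder paths; any dataset the wrapper supports works the same way.
dataset = datasets.CocoDetection("path/to/images", "path/to/annotations.json")
dataset = datasets.wrap_dataset_for_transforms_v2(dataset)

image, target = dataset[0]
# target values are now tv_tensors (e.g. BoundingBoxes) instead of plain Python containers.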
-from . import datapoints, models, transforms, utils
+from . import models, transforms, tv_tensors, utils
@@ -6,8 +6,6 @@ import numpy as np
 import torch
 from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper
-from torchvision.datapoints import BoundingBoxes
-from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,

@@ -16,6 +14,8 @@ from torchvision.prototype.datasets.utils._internal import (
     read_categories_file,
     read_mat,
 )
+from torchvision.prototype.tv_tensors import Label
+from torchvision.tv_tensors import BoundingBoxes
 from .._api import register_dataset, register_info
...
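The same import split repeats in the prototype dataset files below: the prototype-only Label wrapper now comes from torchvision.prototype.tv_tensors, while the stable wrappers (BoundingBoxes, Mask, Image) come from torchvision.tv_tensors. A rough sketch of a sample built with both, with keys and values that are purely illustrative:

import torch
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat

sample = {
    "boxes": BoundingBoxes(
        torch.tensor([[10.0, 20.0, 50.0, 60.0]]),
        format=BoundingBoxFormat.XYXY,
        canvas_size=(224, 224),
    ),
    "label": Label(torch.tensor(3)),  # prototype-only wrapper; its API may change
}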
@@ -4,8 +4,6 @@ from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tuple
 import torch
 from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper
-from torchvision.datapoints import BoundingBoxes
-from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,

@@ -14,6 +12,8 @@ from torchvision.prototype.datasets.utils._internal import (
     INFINITE_BUFFER_SIZE,
     path_accessor,
 )
+from torchvision.prototype.tv_tensors import Label
+from torchvision.tv_tensors import BoundingBoxes
 from .._api import register_dataset, register_info
...
@@ -6,8 +6,6 @@ from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Tuple, Union
 import numpy as np
 from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper
-from torchvision.datapoints import Image
-from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,

@@ -15,6 +13,8 @@ from torchvision.prototype.datasets.utils._internal import (
     path_comparator,
     read_categories_file,
 )
+from torchvision.prototype.tv_tensors import Label
+from torchvision.tv_tensors import Image
 from .._api import register_dataset, register_info
...
@@ -2,7 +2,6 @@ import pathlib
 from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
 from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, JsonParser, Mapper, UnBatcher
-from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,

@@ -12,6 +11,7 @@ from torchvision.prototype.datasets.utils._internal import (
     path_accessor,
     path_comparator,
 )
+from torchvision.prototype.tv_tensors import Label
 from .._api import register_dataset, register_info
...
@@ -14,8 +14,6 @@ from torchdata.datapipes.iter import (
     Mapper,
     UnBatcher,
 )
-from torchvision.datapoints import BoundingBoxes, Mask
-from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,

@@ -26,6 +24,8 @@ from torchvision.prototype.datasets.utils._internal import (
     path_accessor,
     read_categories_file,
 )
+from torchvision.prototype.tv_tensors import Label
+from torchvision.tv_tensors import BoundingBoxes, Mask
 from .._api import register_dataset, register_info
...
@@ -2,7 +2,6 @@ import pathlib
 from typing import Any, Dict, List, Tuple, Union
 from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper
-from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     hint_sharding,

@@ -10,6 +9,7 @@ from torchvision.prototype.datasets.utils._internal import (
     path_comparator,
     read_categories_file,
 )
+from torchvision.prototype.tv_tensors import Label
 from .._api import register_dataset, register_info
...
@@ -15,8 +15,6 @@ from torchdata.datapipes.iter import (
     Mapper,
 )
 from torchdata.datapipes.map import IterToMapConverter
-from torchvision.datapoints import BoundingBoxes
-from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,

@@ -28,6 +26,8 @@ from torchvision.prototype.datasets.utils._internal import (
     read_categories_file,
     read_mat,
 )
+from torchvision.prototype.tv_tensors import Label
+from torchvision.tv_tensors import BoundingBoxes
 from .._api import register_dataset, register_info
...
@@ -3,7 +3,6 @@ import pathlib
 from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
 from torchdata.datapipes.iter import CSVParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
-from torchvision.prototype.datapoints import Label
 from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import (
     getitem,

@@ -13,6 +12,7 @@ from torchvision.prototype.datasets.utils._internal import (
     path_comparator,
     read_categories_file,
 )
+from torchvision.prototype.tv_tensors import Label
 from .._api import register_dataset, register_info
...