Unverified Commit d5f4cc38 authored by Nicolas Hug, committed by GitHub

Datapoint -> TVTensor; datapoint[s] -> tv_tensor[s] (#7894)

parent b9447fdd
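
In short, this commit renames the `torchvision.datapoints` namespace to `torchvision.tv_tensors` and the `Datapoint` base class to `TVTensor`; every "datapoint" in identifiers, comments, and error messages in the diff below becomes "tv_tensor". A minimal before/after sketch of what the rename means for user code, using only constructors that appear in the tests in this diff (the "before" lines assume a torchvision build from before this change):

    import torch

    # Before this commit: the datapoints namespace
    from torchvision import datapoints
    img = datapoints.Image(torch.rand(3, 32, 32))
    boxes = datapoints.BoundingBoxes([[0, 0, 10, 10]], format="XYXY", canvas_size=(32, 32))

    # After this commit: the tv_tensors namespace; same classes, new names
    from torchvision import tv_tensors
    img = tv_tensors.Image(torch.rand(3, 32, 32))
    boxes = tv_tensors.BoundingBoxes([[0, 0, 10, 10]], format="XYXY", canvas_size=(32, 32))
    assert isinstance(img, tv_tensors.TVTensor)  # base class was datapoints.Datapoint
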
......@@ -7,11 +7,11 @@ import torch
from common_utils import assert_equal
from prototype_common_utils import make_label
from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
from torchvision.prototype import datapoints, transforms
from torchvision.prototype import transforms, tv_tensors
from torchvision.transforms.v2._utils import check_type, is_pure_tensor
from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_pil_image
from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
from transforms_v2_legacy_utils import (
DEFAULT_EXTRA_DIMS,
make_bounding_boxes,
......@@ -51,7 +51,7 @@ class TestSimpleCopyPaste:
# images, batch size = 2
self.create_fake_image(mocker, Image),
# labels, bboxes, masks
mocker.MagicMock(spec=datapoints.Label),
mocker.MagicMock(spec=tv_tensors.Label),
mocker.MagicMock(spec=BoundingBoxes),
mocker.MagicMock(spec=Mask),
# labels, bboxes, masks
......@@ -63,7 +63,7 @@ class TestSimpleCopyPaste:
transform._extract_image_targets(flat_sample)
@pytest.mark.parametrize("image_type", [Image, PIL.Image.Image, torch.Tensor])
@pytest.mark.parametrize("label_type", [datapoints.Label, datapoints.OneHotLabel])
@pytest.mark.parametrize("label_type", [tv_tensors.Label, tv_tensors.OneHotLabel])
def test__extract_image_targets(self, image_type, label_type, mocker):
transform = transforms.SimpleCopyPaste()
......@@ -101,7 +101,7 @@ class TestSimpleCopyPaste:
assert isinstance(target[key], type_)
assert target[key] in flat_sample
@pytest.mark.parametrize("label_type", [datapoints.Label, datapoints.OneHotLabel])
@pytest.mark.parametrize("label_type", [tv_tensors.Label, tv_tensors.OneHotLabel])
def test__copy_paste(self, label_type):
image = 2 * torch.ones(3, 32, 32)
masks = torch.zeros(2, 32, 32)
......@@ -111,7 +111,7 @@ class TestSimpleCopyPaste:
blending = True
resize_interpolation = InterpolationMode.BILINEAR
antialias = None
if label_type == datapoints.OneHotLabel:
if label_type == tv_tensors.OneHotLabel:
labels = torch.nn.functional.one_hot(labels, num_classes=5)
target = {
"boxes": BoundingBoxes(
......@@ -126,7 +126,7 @@ class TestSimpleCopyPaste:
paste_masks[0, 13:19, 12:18] = 1
paste_masks[1, 15:19, 1:8] = 1
paste_labels = torch.tensor([3, 4])
if label_type == datapoints.OneHotLabel:
if label_type == tv_tensors.OneHotLabel:
paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5)
paste_target = {
"boxes": BoundingBoxes(
......@@ -148,7 +148,7 @@ class TestSimpleCopyPaste:
torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"])
expected_labels = torch.tensor([1, 2, 3, 4])
if label_type == datapoints.OneHotLabel:
if label_type == tv_tensors.OneHotLabel:
expected_labels = torch.nn.functional.one_hot(expected_labels, num_classes=5)
torch.testing.assert_close(output_target["labels"], label_type(expected_labels))
......@@ -258,10 +258,10 @@ class TestFixedSizeCrop:
class TestLabelToOneHot:
def test__transform(self):
categories = ["apple", "pear", "pineapple"]
labels = datapoints.Label(torch.tensor([0, 1, 2, 1]), categories=categories)
labels = tv_tensors.Label(torch.tensor([0, 1, 2, 1]), categories=categories)
transform = transforms.LabelToOneHot()
ohe_labels = transform(labels)
assert isinstance(ohe_labels, datapoints.OneHotLabel)
assert isinstance(ohe_labels, tv_tensors.OneHotLabel)
assert ohe_labels.shape == (4, 3)
assert ohe_labels.categories == labels.categories == categories
......@@ -383,7 +383,7 @@ det_transforms = import_transforms_from_references("detection")
def test_fixed_sized_crop_against_detection_reference():
def make_datapoints():
def make_tv_tensors():
size = (600, 800)
num_objects = 22
......@@ -405,19 +405,19 @@ def test_fixed_sized_crop_against_detection_reference():
yield (tensor_image, target)
datapoint_image = make_image(size=size, color_space="RGB")
tv_tensor_image = make_image(size=size, color_space="RGB")
target = {
"boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
"masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
}
yield (datapoint_image, target)
yield (tv_tensor_image, target)
t = transforms.FixedSizeCrop((1024, 1024), fill=0)
t_ref = det_transforms.FixedSizeCrop((1024, 1024), fill=0)
for dp in make_datapoints():
for dp in make_tv_tensors():
# We should use prototype transform first as reference transform performs inplace target update
torch.manual_seed(12)
output = t(dp)
......
......@@ -13,7 +13,7 @@ import torchvision.transforms.v2 as transforms
from common_utils import assert_equal, cpu_and_cuda
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.ops.boxes import box_iou
from torchvision.transforms.functional import to_pil_image
from torchvision.transforms.v2 import functional as F
......@@ -66,10 +66,10 @@ def auto_augment_adapter(transform, input, device):
adapted_input = {}
image_or_video_found = False
for key, value in input.items():
if isinstance(value, (datapoints.BoundingBoxes, datapoints.Mask)):
if isinstance(value, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
# AA transforms don't support bounding boxes or masks
continue
elif check_type(value, (datapoints.Image, datapoints.Video, is_pure_tensor, PIL.Image.Image)):
elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor, PIL.Image.Image)):
if image_or_video_found:
# AA transforms only support a single image or video
continue
......@@ -99,7 +99,7 @@ def normalize_adapter(transform, input, device):
if isinstance(value, PIL.Image.Image):
# normalize doesn't support PIL images
continue
elif check_type(value, (datapoints.Image, datapoints.Video, is_pure_tensor)):
elif check_type(value, (tv_tensors.Image, tv_tensors.Video, is_pure_tensor)):
# normalize doesn't support integer images
value = F.to_dtype(value, torch.float32, scale=True)
adapted_input[key] = value
......@@ -142,7 +142,7 @@ class TestSmoke:
(transforms.Resize([16, 16], antialias=True), None),
(transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2), antialias=True), None),
(transforms.ClampBoundingBoxes(), None),
(transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
(transforms.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.CXCYWH), None),
(transforms.ConvertImageDtype(), None),
(transforms.GaussianBlur(kernel_size=3), None),
(
......@@ -178,19 +178,19 @@ class TestSmoke:
canvas_size = F.get_size(image_or_video)
input = dict(
image_or_video=image_or_video,
image_datapoint=make_image(size=canvas_size),
video_datapoint=make_video(size=canvas_size),
image_tv_tensor=make_image(size=canvas_size),
video_tv_tensor=make_video(size=canvas_size),
image_pil=next(make_pil_images(sizes=[canvas_size], color_spaces=["RGB"])),
bounding_boxes_xyxy=make_bounding_boxes(
format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(3,)
format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(3,)
),
bounding_boxes_xywh=make_bounding_boxes(
format=datapoints.BoundingBoxFormat.XYWH, canvas_size=canvas_size, batch_dims=(4,)
format=tv_tensors.BoundingBoxFormat.XYWH, canvas_size=canvas_size, batch_dims=(4,)
),
bounding_boxes_cxcywh=make_bounding_boxes(
format=datapoints.BoundingBoxFormat.CXCYWH, canvas_size=canvas_size, batch_dims=(5,)
format=tv_tensors.BoundingBoxFormat.CXCYWH, canvas_size=canvas_size, batch_dims=(5,)
),
bounding_boxes_degenerate_xyxy=datapoints.BoundingBoxes(
bounding_boxes_degenerate_xyxy=tv_tensors.BoundingBoxes(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
......@@ -199,10 +199,10 @@ class TestSmoke:
[0, 2, 1, 1], # x1 < x2, y1 > y2
[2, 2, 1, 1], # x1 > x2, y1 > y2
],
format=datapoints.BoundingBoxFormat.XYXY,
format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=canvas_size,
),
bounding_boxes_degenerate_xywh=datapoints.BoundingBoxes(
bounding_boxes_degenerate_xywh=tv_tensors.BoundingBoxes(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
......@@ -211,10 +211,10 @@ class TestSmoke:
[0, 0, -1, 1], # negative width
[0, 0, -1, -1], # negative height and width
],
format=datapoints.BoundingBoxFormat.XYWH,
format=tv_tensors.BoundingBoxFormat.XYWH,
canvas_size=canvas_size,
),
bounding_boxes_degenerate_cxcywh=datapoints.BoundingBoxes(
bounding_boxes_degenerate_cxcywh=tv_tensors.BoundingBoxes(
[
[0, 0, 0, 0], # no height or width
[0, 0, 0, 1], # no height
......@@ -223,7 +223,7 @@ class TestSmoke:
[0, 0, -1, 1], # negative width
[0, 0, -1, -1], # negative height and width
],
format=datapoints.BoundingBoxFormat.CXCYWH,
format=tv_tensors.BoundingBoxFormat.CXCYWH,
canvas_size=canvas_size,
),
detection_mask=make_detection_mask(size=canvas_size),
......@@ -262,7 +262,7 @@ class TestSmoke:
else:
assert output_item is input_item
if isinstance(input_item, datapoints.BoundingBoxes) and not isinstance(
if isinstance(input_item, tv_tensors.BoundingBoxes) and not isinstance(
transform, transforms.ConvertBoundingBoxFormat
):
assert output_item.format == input_item.format
......@@ -270,9 +270,9 @@ class TestSmoke:
# Enforce that the transform does not turn a degenerate box marked by RandomIoUCrop (or any other future
# transform that does this), back into a valid one.
# TODO: we should test that against all degenerate boxes above
for format in list(datapoints.BoundingBoxFormat):
for format in list(tv_tensors.BoundingBoxFormat):
sample = dict(
boxes=datapoints.BoundingBoxes([[0, 0, 0, 0]], format=format, canvas_size=(224, 244)),
boxes=tv_tensors.BoundingBoxes([[0, 0, 0, 0]], format=format, canvas_size=(224, 244)),
labels=torch.tensor([3]),
)
assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)
......@@ -652,7 +652,7 @@ class TestRandomErasing:
class TestTransform:
@pytest.mark.parametrize(
"inpt_type",
[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
[torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
)
def test_check_transformed_types(self, inpt_type, mocker):
# This test ensures that we correctly handle which types to transform and which to bypass
......@@ -670,7 +670,7 @@ class TestTransform:
class TestToImage:
@pytest.mark.parametrize(
"inpt_type",
[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
[torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch(
......@@ -681,7 +681,7 @@ class TestToImage:
inpt = mocker.MagicMock(spec=inpt_type)
transform = transforms.ToImage()
transform(inpt)
if inpt_type in (datapoints.BoundingBoxes, datapoints.Image, str, int):
if inpt_type in (tv_tensors.BoundingBoxes, tv_tensors.Image, str, int):
assert fn.call_count == 0
else:
fn.assert_called_once_with(inpt)
......@@ -690,7 +690,7 @@ class TestToImage:
class TestToPILImage:
@pytest.mark.parametrize(
"inpt_type",
[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
[torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch("torchvision.transforms.v2.functional.to_pil_image")
......@@ -698,7 +698,7 @@ class TestToPILImage:
inpt = mocker.MagicMock(spec=inpt_type)
transform = transforms.ToPILImage()
transform(inpt)
if inpt_type in (PIL.Image.Image, datapoints.BoundingBoxes, str, int):
if inpt_type in (PIL.Image.Image, tv_tensors.BoundingBoxes, str, int):
assert fn.call_count == 0
else:
fn.assert_called_once_with(inpt, mode=transform.mode)
......@@ -707,7 +707,7 @@ class TestToPILImage:
class TestToTensor:
@pytest.mark.parametrize(
"inpt_type",
[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int],
[torch.Tensor, PIL.Image.Image, tv_tensors.Image, np.ndarray, tv_tensors.BoundingBoxes, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch("torchvision.transforms.functional.to_tensor")
......@@ -716,7 +716,7 @@ class TestToTensor:
with pytest.warns(UserWarning, match="deprecated and will be removed"):
transform = transforms.ToTensor()
transform(inpt)
if inpt_type in (datapoints.Image, torch.Tensor, datapoints.BoundingBoxes, str, int):
if inpt_type in (tv_tensors.Image, torch.Tensor, tv_tensors.BoundingBoxes, str, int):
assert fn.call_count == 0
else:
fn.assert_called_once_with(inpt)
......@@ -757,7 +757,7 @@ class TestRandomIoUCrop:
def test__get_params(self, device, options):
orig_h, orig_w = size = (24, 32)
image = make_image(size)
bboxes = datapoints.BoundingBoxes(
bboxes = tv_tensors.BoundingBoxes(
torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]),
format="XYXY",
canvas_size=size,
......@@ -792,8 +792,8 @@ class TestRandomIoUCrop:
def test__transform_empty_params(self, mocker):
transform = transforms.RandomIoUCrop(sampler_options=[2.0])
image = datapoints.Image(torch.rand(1, 3, 4, 4))
bboxes = datapoints.BoundingBoxes(torch.tensor([[1, 1, 2, 2]]), format="XYXY", canvas_size=(4, 4))
image = tv_tensors.Image(torch.rand(1, 3, 4, 4))
bboxes = tv_tensors.BoundingBoxes(torch.tensor([[1, 1, 2, 2]]), format="XYXY", canvas_size=(4, 4))
label = torch.tensor([1])
sample = [image, bboxes, label]
# Let's mock transform._get_params to control the output:
......@@ -827,11 +827,11 @@ class TestRandomIoUCrop:
# check number of bboxes vs number of labels:
output_bboxes = output[1]
assert isinstance(output_bboxes, datapoints.BoundingBoxes)
assert isinstance(output_bboxes, tv_tensors.BoundingBoxes)
assert (output_bboxes[~is_within_crop_area] == 0).all()
output_masks = output[2]
assert isinstance(output_masks, datapoints.Mask)
assert isinstance(output_masks, tv_tensors.Mask)
class TestScaleJitter:
......@@ -899,7 +899,7 @@ class TestLinearTransformation:
[
122 * torch.ones(1, 3, 8, 8),
122.0 * torch.ones(1, 3, 8, 8),
datapoints.Image(122 * torch.ones(1, 3, 8, 8)),
tv_tensors.Image(122 * torch.ones(1, 3, 8, 8)),
PIL.Image.new("RGB", (8, 8), (122, 122, 122)),
],
)
......@@ -941,7 +941,7 @@ class TestUniformTemporalSubsample:
[
torch.zeros(10, 3, 8, 8),
torch.zeros(1, 10, 3, 8, 8),
datapoints.Video(torch.zeros(1, 10, 3, 8, 8)),
tv_tensors.Video(torch.zeros(1, 10, 3, 8, 8)),
],
)
def test__transform(self, inpt):
......@@ -971,12 +971,12 @@ def test_antialias_warning():
transforms.RandomResize(10, 20)(tensor_img)
with pytest.warns(UserWarning, match=match):
F.resized_crop(datapoints.Image(tensor_img), 0, 0, 10, 10, (20, 20))
F.resized_crop(tv_tensors.Image(tensor_img), 0, 0, 10, 10, (20, 20))
with pytest.warns(UserWarning, match=match):
F.resize(datapoints.Video(tensor_video), (20, 20))
F.resize(tv_tensors.Video(tensor_video), (20, 20))
with pytest.warns(UserWarning, match=match):
F.resized_crop(datapoints.Video(tensor_video), 0, 0, 10, 10, (20, 20))
F.resized_crop(tv_tensors.Video(tensor_video), 0, 0, 10, 10, (20, 20))
with warnings.catch_warnings():
warnings.simplefilter("error")
......@@ -990,17 +990,17 @@ def test_antialias_warning():
transforms.RandomShortestSize((20, 20), antialias=True)(tensor_img)
transforms.RandomResize(10, 20, antialias=True)(tensor_img)
F.resized_crop(datapoints.Image(tensor_img), 0, 0, 10, 10, (20, 20), antialias=True)
F.resized_crop(datapoints.Video(tensor_video), 0, 0, 10, 10, (20, 20), antialias=True)
F.resized_crop(tv_tensors.Image(tensor_img), 0, 0, 10, 10, (20, 20), antialias=True)
F.resized_crop(tv_tensors.Video(tensor_video), 0, 0, 10, 10, (20, 20), antialias=True)
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
@pytest.mark.parametrize("label_type", (torch.Tensor, int))
@pytest.mark.parametrize("dataset_return_type", (dict, tuple))
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImage))
def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
image = datapoints.Image(torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8))
image = tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8))
if image_type is PIL.Image:
image = to_pil_image(image[0])
elif image_type is torch.Tensor:
......@@ -1056,7 +1056,7 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
assert out_label == label
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
@pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImage))
@pytest.mark.parametrize("sanitize", (True, False))
......@@ -1082,7 +1082,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
# leaving FixedSizeCrop in prototype for now, and it expects Label
# classes which we won't release yet.
# transforms.FixedSizeCrop(
# size=(1024, 1024), fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})
# size=(1024, 1024), fill=defaultdict(lambda: (123.0, 117.0, 104.0), {tv_tensors.Mask: 0})
# ),
transforms.RandomCrop((1024, 1024), pad_if_needed=True),
transforms.RandomHorizontalFlip(p=1),
......@@ -1101,7 +1101,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
elif data_augmentation == "ssd":
t = [
transforms.RandomPhotometricDistort(p=1),
transforms.RandomZoomOut(fill={"others": (123.0, 117.0, 104.0), datapoints.Mask: 0}, p=1),
transforms.RandomZoomOut(fill={"others": (123.0, 117.0, 104.0), tv_tensors.Mask: 0}, p=1),
transforms.RandomIoUCrop(),
transforms.RandomHorizontalFlip(p=1),
to_tensor,
......@@ -1121,7 +1121,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
num_boxes = 5
H = W = 250
image = datapoints.Image(torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8))
image = tv_tensors.Image(torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8))
if image_type is PIL.Image:
image = to_pil_image(image[0])
elif image_type is torch.Tensor:
......@@ -1133,9 +1133,9 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4))
boxes[:, 2:] += boxes[:, :2]
boxes = boxes.clamp(min=0, max=min(H, W))
boxes = datapoints.BoundingBoxes(boxes, format="XYXY", canvas_size=(H, W))
boxes = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=(H, W))
masks = datapoints.Mask(torch.randint(0, 2, size=(num_boxes, H, W), dtype=torch.uint8))
masks = tv_tensors.Mask(torch.randint(0, 2, size=(num_boxes, H, W), dtype=torch.uint8))
sample = {
"image": image,
......@@ -1146,10 +1146,10 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
out = t(sample)
if isinstance(to_tensor, transforms.ToTensor) and image_type is not datapoints.Image:
if isinstance(to_tensor, transforms.ToTensor) and image_type is not tv_tensors.Image:
assert is_pure_tensor(out["image"])
else:
assert isinstance(out["image"], datapoints.Image)
assert isinstance(out["image"], tv_tensors.Image)
assert isinstance(out["label"], type(sample["label"]))
num_boxes_expected = {
......@@ -1204,13 +1204,13 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
boxes = torch.tensor(boxes)
labels = torch.arange(boxes.shape[0])
boxes = datapoints.BoundingBoxes(
boxes = tv_tensors.BoundingBoxes(
boxes,
format=datapoints.BoundingBoxFormat.XYXY,
format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=(H, W),
)
masks = datapoints.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
masks = tv_tensors.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
whatever = torch.rand(10)
input_img = torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8)
sample = {
......@@ -1244,8 +1244,8 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
assert out_image is input_img
assert out_whatever is whatever
assert isinstance(out_boxes, datapoints.BoundingBoxes)
assert isinstance(out_masks, datapoints.Mask)
assert isinstance(out_boxes, tv_tensors.BoundingBoxes)
assert isinstance(out_masks, tv_tensors.Mask)
if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None):
assert out_labels is labels
......@@ -1266,15 +1266,15 @@ def test_sanitize_bounding_boxes_no_label():
transforms.SanitizeBoundingBoxes()(img, boxes)
out_img, out_boxes = transforms.SanitizeBoundingBoxes(labels_getter=None)(img, boxes)
assert isinstance(out_img, datapoints.Image)
assert isinstance(out_boxes, datapoints.BoundingBoxes)
assert isinstance(out_img, tv_tensors.Image)
assert isinstance(out_boxes, tv_tensors.BoundingBoxes)
def test_sanitize_bounding_boxes_errors():
good_bbox = datapoints.BoundingBoxes(
good_bbox = tv_tensors.BoundingBoxes(
[[0, 0, 10, 10]],
format=datapoints.BoundingBoxFormat.XYXY,
format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=(20, 20),
)
......
......@@ -13,7 +13,7 @@ import torch
import torchvision.transforms.v2 as v2_transforms
from common_utils import assert_close, assert_equal, set_rng_seed
from torch import nn
from torchvision import datapoints, transforms as legacy_transforms
from torchvision import transforms as legacy_transforms, tv_tensors
from torchvision._utils import sequence_to_str
from torchvision.transforms import functional as legacy_F
......@@ -478,15 +478,15 @@ def check_call_consistency(
output_prototype_image = prototype_transform(image)
except Exception as exc:
raise AssertionError(
f"Transforming a image datapoint with shape {image_repr} failed in the prototype transform with "
f"Transforming a image tv_tensor with shape {image_repr} failed in the prototype transform with "
f"the error above. This means there is a consistency bug either in `_get_params` or in the "
f"`datapoints.Image` path in `_transform`."
f"`tv_tensors.Image` path in `_transform`."
) from exc
assert_close(
output_prototype_image,
output_prototype_tensor,
msg=lambda msg: f"Output for datapoint and tensor images is not equal: \n\n{msg}",
msg=lambda msg: f"Output for tv_tensor and tensor images is not equal: \n\n{msg}",
**closeness_kwargs,
)
......@@ -747,7 +747,7 @@ class TestAATransforms:
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize(
......@@ -812,7 +812,7 @@ class TestAATransforms:
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize(
......@@ -887,7 +887,7 @@ class TestAATransforms:
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize(
......@@ -964,7 +964,7 @@ class TestAATransforms:
[
torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8),
PIL.Image.new("RGB", (256, 256), 123),
datapoints.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
tv_tensors.Image(torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)),
],
)
@pytest.mark.parametrize(
......@@ -1030,7 +1030,7 @@ det_transforms = import_transforms_from_references("detection")
class TestRefDetTransforms:
def make_datapoints(self, with_mask=True):
def make_tv_tensors(self, with_mask=True):
size = (600, 800)
num_objects = 22
......@@ -1057,7 +1057,7 @@ class TestRefDetTransforms:
yield (tensor_image, target)
datapoint_image = make_image(size=size, color_space="RGB", dtype=torch.float32)
tv_tensor_image = make_image(size=size, color_space="RGB", dtype=torch.float32)
target = {
"boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
......@@ -1065,7 +1065,7 @@ class TestRefDetTransforms:
if with_mask:
target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)
yield (datapoint_image, target)
yield (tv_tensor_image, target)
@pytest.mark.parametrize(
"t_ref, t, data_kwargs",
......@@ -1095,7 +1095,7 @@ class TestRefDetTransforms:
],
)
def test_transform(self, t_ref, t, data_kwargs):
for dp in self.make_datapoints(**data_kwargs):
for dp in self.make_tv_tensors(**data_kwargs):
# We should use prototype transform first as reference transform performs inplace target update
torch.manual_seed(12)
......@@ -1135,7 +1135,7 @@ class PadIfSmaller(v2_transforms.Transform):
class TestRefSegTransforms:
def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):
def make_tv_tensors(self, supports_pil=True, image_dtype=torch.uint8):
size = (256, 460)
num_categories = 21
......@@ -1145,13 +1145,13 @@ class TestRefSegTransforms:
conv_fns.extend([torch.Tensor, lambda x: x])
for conv_fn in conv_fns:
datapoint_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
datapoint_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)
tv_tensor_image = make_image(size=size, color_space="RGB", dtype=image_dtype)
tv_tensor_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)
dp = (conv_fn(datapoint_image), datapoint_mask)
dp = (conv_fn(tv_tensor_image), tv_tensor_mask)
dp_ref = (
to_pil_image(datapoint_image) if supports_pil else datapoint_image.as_subclass(torch.Tensor),
to_pil_image(datapoint_mask),
to_pil_image(tv_tensor_image) if supports_pil else tv_tensor_image.as_subclass(torch.Tensor),
to_pil_image(tv_tensor_mask),
)
yield dp, dp_ref
......@@ -1161,7 +1161,7 @@ class TestRefSegTransforms:
random.seed(seed)
def check(self, t, t_ref, data_kwargs=None):
for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()):
for dp, dp_ref in self.make_tv_tensors(**data_kwargs or dict()):
self.set_seed()
actual = actual_image, actual_mask = t(dp)
......@@ -1192,7 +1192,7 @@ class TestRefSegTransforms:
seg_transforms.RandomCrop(size=480),
v2_transforms.Compose(
[
PadIfSmaller(size=480, fill={datapoints.Mask: 255, "others": 0}),
PadIfSmaller(size=480, fill={tv_tensors.Mask: 255, "others": 0}),
v2_transforms.RandomCrop(size=480),
]
),
......
......@@ -10,7 +10,7 @@ import torch
from common_utils import assert_close, cache, cpu_and_cuda, needs_cuda, set_rng_seed
from torch.utils._pytree import tree_map
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2._utils import is_pure_tensor
......@@ -164,22 +164,22 @@ class TestKernels:
def test_batched_vs_single(self, test_id, info, args_kwargs, device):
(batched_input, *other_args), kwargs = args_kwargs.load(device)
datapoint_type = datapoints.Image if is_pure_tensor(batched_input) else type(batched_input)
tv_tensor_type = tv_tensors.Image if is_pure_tensor(batched_input) else type(batched_input)
# This dictionary contains the number of rightmost dimensions that contain the actual data.
# Everything to the left is considered a batch dimension.
data_dims = {
datapoints.Image: 3,
datapoints.BoundingBoxes: 1,
tv_tensors.Image: 3,
tv_tensors.BoundingBoxes: 1,
# `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection masks
# it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both are grouped under one
# type, all kernels should also work without differentiating between the two. Thus, we go with 2 here as
# common ground.
datapoints.Mask: 2,
datapoints.Video: 4,
}.get(datapoint_type)
tv_tensors.Mask: 2,
tv_tensors.Video: 4,
}.get(tv_tensor_type)
if data_dims is None:
raise pytest.UsageError(
f"The number of data dimensions cannot be determined for input of type {datapoint_type.__name__}."
f"The number of data dimensions cannot be determined for input of type {tv_tensor_type.__name__}."
) from None
elif batched_input.ndim <= data_dims:
pytest.skip("Input is not batched.")
......@@ -305,8 +305,8 @@ def spy_on(mocker):
class TestDispatchers:
image_sample_inputs = make_info_args_kwargs_parametrization(
[info for info in DISPATCHER_INFOS if datapoints.Image in info.kernels],
args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
[info for info in DISPATCHER_INFOS if tv_tensors.Image in info.kernels],
args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.Image),
)
@make_info_args_kwargs_parametrization(
......@@ -328,8 +328,8 @@ class TestDispatchers:
def test_scripted_smoke(self, info, args_kwargs, device):
dispatcher = script(info.dispatcher)
(image_datapoint, *other_args), kwargs = args_kwargs.load(device)
image_pure_tensor = torch.Tensor(image_datapoint)
(image_tv_tensor, *other_args), kwargs = args_kwargs.load(device)
image_pure_tensor = torch.Tensor(image_tv_tensor)
dispatcher(image_pure_tensor, *other_args, **kwargs)
......@@ -355,25 +355,25 @@ class TestDispatchers:
@image_sample_inputs
def test_pure_tensor_output_type(self, info, args_kwargs):
(image_datapoint, *other_args), kwargs = args_kwargs.load()
image_pure_tensor = image_datapoint.as_subclass(torch.Tensor)
(image_tv_tensor, *other_args), kwargs = args_kwargs.load()
image_pure_tensor = image_tv_tensor.as_subclass(torch.Tensor)
output = info.dispatcher(image_pure_tensor, *other_args, **kwargs)
# We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well
# We cannot use `isinstance` here since all tv_tensors are instances of `torch.Tensor` as well
assert type(output) is torch.Tensor
@make_info_args_kwargs_parametrization(
[info for info in DISPATCHER_INFOS if info.pil_kernel_info is not None],
args_kwargs_fn=lambda info: info.sample_inputs(datapoints.Image),
args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.Image),
)
def test_pil_output_type(self, info, args_kwargs):
(image_datapoint, *other_args), kwargs = args_kwargs.load()
(image_tv_tensor, *other_args), kwargs = args_kwargs.load()
if image_datapoint.ndim > 3:
if image_tv_tensor.ndim > 3:
pytest.skip("Input is batched")
image_pil = F.to_pil_image(image_datapoint)
image_pil = F.to_pil_image(image_tv_tensor)
output = info.dispatcher(image_pil, *other_args, **kwargs)
......@@ -383,38 +383,38 @@ class TestDispatchers:
DISPATCHER_INFOS,
args_kwargs_fn=lambda info: info.sample_inputs(),
)
def test_datapoint_output_type(self, info, args_kwargs):
(datapoint, *other_args), kwargs = args_kwargs.load()
def test_tv_tensor_output_type(self, info, args_kwargs):
(tv_tensor, *other_args), kwargs = args_kwargs.load()
output = info.dispatcher(datapoint, *other_args, **kwargs)
output = info.dispatcher(tv_tensor, *other_args, **kwargs)
assert isinstance(output, type(datapoint))
assert isinstance(output, type(tv_tensor))
if isinstance(datapoint, datapoints.BoundingBoxes) and info.dispatcher is not F.convert_bounding_box_format:
assert output.format == datapoint.format
if isinstance(tv_tensor, tv_tensors.BoundingBoxes) and info.dispatcher is not F.convert_bounding_box_format:
assert output.format == tv_tensor.format
@pytest.mark.parametrize(
("dispatcher_info", "datapoint_type", "kernel_info"),
("dispatcher_info", "tv_tensor_type", "kernel_info"),
[
pytest.param(
dispatcher_info, datapoint_type, kernel_info, id=f"{dispatcher_info.id}-{datapoint_type.__name__}"
dispatcher_info, tv_tensor_type, kernel_info, id=f"{dispatcher_info.id}-{tv_tensor_type.__name__}"
)
for dispatcher_info in DISPATCHER_INFOS
for datapoint_type, kernel_info in dispatcher_info.kernel_infos.items()
for tv_tensor_type, kernel_info in dispatcher_info.kernel_infos.items()
],
)
def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoint_type, kernel_info):
def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, tv_tensor_type, kernel_info):
dispatcher_signature = inspect.signature(dispatcher_info.dispatcher)
dispatcher_params = list(dispatcher_signature.parameters.values())[1:]
kernel_signature = inspect.signature(kernel_info.kernel)
kernel_params = list(kernel_signature.parameters.values())[1:]
# We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
# We filter out metadata that is implicitly passed to the dispatcher through the input tv_tensor, but has to be
# explicitly passed to the kernel.
input_type = {v: k for k, v in dispatcher_info.kernels.items()}.get(kernel_info.kernel)
explicit_metadata = {
datapoints.BoundingBoxes: {"format", "canvas_size"},
tv_tensors.BoundingBoxes: {"format", "canvas_size"},
}
kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]
......@@ -445,9 +445,9 @@ class TestDispatchers:
[
info
for info in DISPATCHER_INFOS
if datapoints.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_bounding_box_format
if tv_tensors.BoundingBoxes in info.kernels and info.dispatcher is not F.convert_bounding_box_format
],
args_kwargs_fn=lambda info: info.sample_inputs(datapoints.BoundingBoxes),
args_kwargs_fn=lambda info: info.sample_inputs(tv_tensors.BoundingBoxes),
)
def test_bounding_boxes_format_consistency(self, info, args_kwargs):
(bounding_boxes, *other_args), kwargs = args_kwargs.load()
......@@ -497,7 +497,7 @@ class TestClampBoundingBoxes:
"metadata",
[
dict(),
dict(format=datapoints.BoundingBoxFormat.XYXY),
dict(format=tv_tensors.BoundingBoxFormat.XYXY),
dict(canvas_size=(1, 1)),
],
)
......@@ -510,16 +510,16 @@ class TestClampBoundingBoxes:
@pytest.mark.parametrize(
"metadata",
[
dict(format=datapoints.BoundingBoxFormat.XYXY),
dict(format=tv_tensors.BoundingBoxFormat.XYXY),
dict(canvas_size=(1, 1)),
dict(format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(1, 1)),
dict(format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=(1, 1)),
],
)
def test_datapoint_explicit_metadata(self, metadata):
datapoint = next(make_multiple_bounding_boxes())
def test_tv_tensor_explicit_metadata(self, metadata):
tv_tensor = next(make_multiple_bounding_boxes())
with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` must not be passed")):
F.clamp_bounding_boxes(datapoint, **metadata)
F.clamp_bounding_boxes(tv_tensor, **metadata)
class TestConvertFormatBoundingBoxes:
......@@ -527,7 +527,7 @@ class TestConvertFormatBoundingBoxes:
("inpt", "old_format"),
[
(next(make_multiple_bounding_boxes()), None),
(next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor), datapoints.BoundingBoxFormat.XYXY),
(next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor), tv_tensors.BoundingBoxFormat.XYXY),
],
)
def test_missing_new_format(self, inpt, old_format):
......@@ -538,14 +538,14 @@ class TestConvertFormatBoundingBoxes:
pure_tensor = next(make_multiple_bounding_boxes()).as_subclass(torch.Tensor)
with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
F.convert_bounding_box_format(pure_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)
F.convert_bounding_box_format(pure_tensor, new_format=tv_tensors.BoundingBoxFormat.CXCYWH)
def test_datapoint_explicit_metadata(self):
datapoint = next(make_multiple_bounding_boxes())
def test_tv_tensor_explicit_metadata(self):
tv_tensor = next(make_multiple_bounding_boxes())
with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")):
F.convert_bounding_box_format(
datapoint, old_format=datapoint.format, new_format=datapoints.BoundingBoxFormat.CXCYWH
tv_tensor, old_format=tv_tensor.format, new_format=tv_tensors.BoundingBoxFormat.CXCYWH
)
......@@ -579,7 +579,7 @@ def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
"format",
[datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
[tv_tensors.BoundingBoxFormat.XYXY, tv_tensors.BoundingBoxFormat.XYWH, tv_tensors.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
"top, left, height, width, expected_bboxes",
......@@ -602,7 +602,7 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
# out_box = denormalize_bbox(n_out_box, height, width)
# expected_bboxes.append(out_box)
format = datapoints.BoundingBoxFormat.XYXY
format = tv_tensors.BoundingBoxFormat.XYXY
canvas_size = (64, 76)
in_boxes = [
[10.0, 15.0, 25.0, 35.0],
......@@ -610,11 +610,11 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
[45.0, 46.0, 56.0, 62.0],
]
in_boxes = torch.tensor(in_boxes, device=device)
if format != datapoints.BoundingBoxFormat.XYXY:
in_boxes = convert_bounding_box_format(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
if format != tv_tensors.BoundingBoxFormat.XYXY:
in_boxes = convert_bounding_box_format(in_boxes, tv_tensors.BoundingBoxFormat.XYXY, format)
expected_bboxes = clamp_bounding_boxes(
datapoints.BoundingBoxes(expected_bboxes, format="XYXY", canvas_size=canvas_size)
tv_tensors.BoundingBoxes(expected_bboxes, format="XYXY", canvas_size=canvas_size)
).tolist()
output_boxes, output_canvas_size = F.crop_bounding_boxes(
......@@ -626,8 +626,8 @@ def test_correctness_crop_bounding_boxes(device, format, top, left, height, widt
canvas_size[1],
)
if format != datapoints.BoundingBoxFormat.XYXY:
output_boxes = convert_bounding_box_format(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
if format != tv_tensors.BoundingBoxFormat.XYXY:
output_boxes = convert_bounding_box_format(output_boxes, format, tv_tensors.BoundingBoxFormat.XYXY)
torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
torch.testing.assert_close(output_canvas_size, canvas_size)
......@@ -648,7 +648,7 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize(
"format",
[datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH, datapoints.BoundingBoxFormat.CXCYWH],
[tv_tensors.BoundingBoxFormat.XYXY, tv_tensors.BoundingBoxFormat.XYWH, tv_tensors.BoundingBoxFormat.CXCYWH],
)
@pytest.mark.parametrize(
"top, left, height, width, size",
......@@ -666,7 +666,7 @@ def test_correctness_resized_crop_bounding_boxes(device, format, top, left, heig
bbox[3] = (bbox[3] - top_) * size_[0] / height_
return bbox
format = datapoints.BoundingBoxFormat.XYXY
format = tv_tensors.BoundingBoxFormat.XYXY
canvas_size = (100, 100)
in_boxes = [
[10.0, 10.0, 20.0, 20.0],
......@@ -677,16 +677,16 @@ def test_correctness_resized_crop_bounding_boxes(device, format, top, left, heig
expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size))
expected_bboxes = torch.tensor(expected_bboxes, device=device)
in_boxes = datapoints.BoundingBoxes(
in_boxes, format=datapoints.BoundingBoxFormat.XYXY, canvas_size=canvas_size, device=device
in_boxes = tv_tensors.BoundingBoxes(
in_boxes, format=tv_tensors.BoundingBoxFormat.XYXY, canvas_size=canvas_size, device=device
)
if format != datapoints.BoundingBoxFormat.XYXY:
in_boxes = convert_bounding_box_format(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)
if format != tv_tensors.BoundingBoxFormat.XYXY:
in_boxes = convert_bounding_box_format(in_boxes, tv_tensors.BoundingBoxFormat.XYXY, format)
output_boxes, output_canvas_size = F.resized_crop_bounding_boxes(in_boxes, format, top, left, height, width, size)
if format != datapoints.BoundingBoxFormat.XYXY:
output_boxes = convert_bounding_box_format(output_boxes, format, datapoints.BoundingBoxFormat.XYXY)
if format != tv_tensors.BoundingBoxFormat.XYXY:
output_boxes = convert_bounding_box_format(output_boxes, format, tv_tensors.BoundingBoxFormat.XYXY)
torch.testing.assert_close(output_boxes, expected_bboxes)
torch.testing.assert_close(output_canvas_size, size)
......@@ -713,14 +713,14 @@ def test_correctness_pad_bounding_boxes(device, padding):
dtype = bbox.dtype
bbox = (
bbox.clone()
if format == datapoints.BoundingBoxFormat.XYXY
else convert_bounding_box_format(bbox, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
if format == tv_tensors.BoundingBoxFormat.XYXY
else convert_bounding_box_format(bbox, old_format=format, new_format=tv_tensors.BoundingBoxFormat.XYXY)
)
bbox[0::2] += pad_left
bbox[1::2] += pad_up
bbox = convert_bounding_box_format(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format)
bbox = convert_bounding_box_format(bbox, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format)
if bbox.dtype != dtype:
# Temporary cast to original dtype
# e.g. float32 -> int
......@@ -785,7 +785,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
]
)
bbox_xyxy = convert_bounding_box_format(bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY)
bbox_xyxy = convert_bounding_box_format(bbox, old_format=format_, new_format=tv_tensors.BoundingBoxFormat.XYXY)
points = np.array(
[
[bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
......@@ -807,7 +807,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
)
out_bbox = torch.from_numpy(out_bbox)
out_bbox = convert_bounding_box_format(
out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_
out_bbox, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format_
)
return clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_).to(bbox)
......@@ -846,7 +846,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
def test_correctness_center_crop_bounding_boxes(device, output_size):
def _compute_expected_bbox(bbox, format_, canvas_size_, output_size_):
dtype = bbox.dtype
bbox = convert_bounding_box_format(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)
bbox = convert_bounding_box_format(bbox.float(), format_, tv_tensors.BoundingBoxFormat.XYWH)
if len(output_size_) == 1:
output_size_.append(output_size_[-1])
......@@ -860,7 +860,7 @@ def test_correctness_center_crop_bounding_boxes(device, output_size):
bbox[3].item(),
]
out_bbox = torch.tensor(out_bbox)
out_bbox = convert_bounding_box_format(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
out_bbox = convert_bounding_box_format(out_bbox, tv_tensors.BoundingBoxFormat.XYWH, format_)
out_bbox = clamp_bounding_boxes(out_bbox, format=format_, canvas_size=output_size)
return out_bbox.to(dtype=dtype, device=bbox.device)
......@@ -958,7 +958,7 @@ def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize,
torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)
)
image = datapoints.Image(tensor)
image = tv_tensors.Image(tensor)
out = fn(image, kernel_size=ksize, sigma=sigma)
torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")
......
......@@ -36,7 +36,7 @@ from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.functional import pil_modes_mapping
......@@ -167,7 +167,7 @@ def check_kernel(
def _check_functional_scripted_smoke(functional, input, *args, **kwargs):
"""Checks if the functional can be scripted and the scripted version can be called without error."""
if not isinstance(input, datapoints.Image):
if not isinstance(input, tv_tensors.Image):
return
functional_scripted = _script(functional)
......@@ -187,7 +187,7 @@ def check_functional(functional, input, *args, check_scripted_smoke=True, **kwar
assert isinstance(output, type(input))
if isinstance(input, datapoints.BoundingBoxes):
if isinstance(input, tv_tensors.BoundingBoxes):
assert output.format == input.format
if check_scripted_smoke:
......@@ -199,11 +199,11 @@ def check_functional_kernel_signature_match(functional, *, kernel, input_type):
functional_params = list(inspect.signature(functional).parameters.values())[1:]
kernel_params = list(inspect.signature(kernel).parameters.values())[1:]
if issubclass(input_type, datapoints.Datapoint):
# We filter out metadata that is implicitly passed to the functional through the input datapoint, but has to be
if issubclass(input_type, tv_tensors.TVTensor):
# We filter out metadata that is implicitly passed to the functional through the input tv_tensor, but has to be
# explicitly passed to the kernel.
explicit_metadata = {
datapoints.BoundingBoxes: {"format", "canvas_size"},
tv_tensors.BoundingBoxes: {"format", "canvas_size"},
}
kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]
......@@ -264,7 +264,7 @@ def check_transform(transform, input, check_v1_compatibility=True):
output = transform(input)
assert isinstance(output, type(input))
if isinstance(input, datapoints.BoundingBoxes):
if isinstance(input, tv_tensors.BoundingBoxes):
assert output.format == input.format
if check_v1_compatibility:
......@@ -362,7 +362,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new
input_xyxy = F.convert_bounding_box_format(
bounding_boxes.to(torch.float64, copy=True),
old_format=format,
new_format=datapoints.BoundingBoxFormat.XYXY,
new_format=tv_tensors.BoundingBoxFormat.XYXY,
inplace=True,
)
x1, y1, x2, y2 = input_xyxy.squeeze(0).tolist()
......@@ -387,7 +387,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new
)
output = F.convert_bounding_box_format(
output_xyxy, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format
output_xyxy, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format
)
if clamp:
......@@ -400,7 +400,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new
return output
return datapoints.BoundingBoxes(
return tv_tensors.BoundingBoxes(
torch.cat([affine_bounding_boxes(b) for b in bounding_boxes.reshape(-1, 4).unbind()], dim=0).reshape(
bounding_boxes.shape
),
......@@ -479,7 +479,7 @@ class TestResize:
check_scripted_vs_eager=not isinstance(size, int),
)
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("size", OUTPUT_SIZES)
@pytest.mark.parametrize("use_max_size", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
......@@ -529,10 +529,10 @@ class TestResize:
[
(F.resize_image, torch.Tensor),
(F._resize_image_pil, PIL.Image.Image),
(F.resize_image, datapoints.Image),
(F.resize_bounding_boxes, datapoints.BoundingBoxes),
(F.resize_mask, datapoints.Mask),
(F.resize_video, datapoints.Video),
(F.resize_image, tv_tensors.Image),
(F.resize_bounding_boxes, tv_tensors.BoundingBoxes),
(F.resize_mask, tv_tensors.Mask),
(F.resize_video, tv_tensors.Video),
],
)
def test_functional_signature(self, kernel, input_type):
......@@ -605,7 +605,7 @@ class TestResize:
new_canvas_size=(new_height, new_width),
)
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("size", OUTPUT_SIZES)
@pytest.mark.parametrize("use_max_size", [True, False])
@pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)])
......@@ -734,9 +734,9 @@ class TestResize:
# This identity check is not a requirement. It is here to avoid breaking the behavior by accident. If there
# is a good reason to break this, feel free to downgrade to an equality check.
if isinstance(input, datapoints.Datapoint):
if isinstance(input, tv_tensors.TVTensor):
# We can't test identity directly, since that checks for the identity of the Python object. Since all
# datapoints unwrap before a kernel and wrap again afterwards, the Python object changes. Thus, we check
# tv_tensors unwrap before a kernel and wrap again afterwards, the Python object changes. Thus, we check
# that the underlying storage is the same
assert output.data_ptr() == input.data_ptr()
else:
......@@ -782,7 +782,7 @@ class TestResize:
)
if emulate_channels_last:
image = datapoints.wrap(image.view(*batch_dims, *image.shape[-3:]), like=image)
image = tv_tensors.wrap(image.view(*batch_dims, *image.shape[-3:]), like=image)
return image
......@@ -833,7 +833,7 @@ class TestHorizontalFlip:
def test_kernel_image_tensor(self, dtype, device):
check_kernel(F.horizontal_flip_image, make_image(dtype=dtype, device=device))
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_boxes(self, format, dtype, device):
......@@ -864,10 +864,10 @@ class TestHorizontalFlip:
[
(F.horizontal_flip_image, torch.Tensor),
(F._horizontal_flip_image_pil, PIL.Image.Image),
(F.horizontal_flip_image, datapoints.Image),
(F.horizontal_flip_bounding_boxes, datapoints.BoundingBoxes),
(F.horizontal_flip_mask, datapoints.Mask),
(F.horizontal_flip_video, datapoints.Video),
(F.horizontal_flip_image, tv_tensors.Image),
(F.horizontal_flip_bounding_boxes, tv_tensors.BoundingBoxes),
(F.horizontal_flip_mask, tv_tensors.Mask),
(F.horizontal_flip_video, tv_tensors.Video),
],
)
def test_functional_signature(self, kernel, input_type):
......@@ -902,7 +902,7 @@ class TestHorizontalFlip:
return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize(
"fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
)
......@@ -999,7 +999,7 @@ class TestAffine:
shear=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["shear"],
center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"],
)
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_boxes(self, param, value, format, dtype, device):
......@@ -1032,10 +1032,10 @@ class TestAffine:
[
(F.affine_image, torch.Tensor),
(F._affine_image_pil, PIL.Image.Image),
(F.affine_image, datapoints.Image),
(F.affine_bounding_boxes, datapoints.BoundingBoxes),
(F.affine_mask, datapoints.Mask),
(F.affine_video, datapoints.Video),
(F.affine_image, tv_tensors.Image),
(F.affine_bounding_boxes, tv_tensors.BoundingBoxes),
(F.affine_mask, tv_tensors.Mask),
(F.affine_video, tv_tensors.Video),
],
)
def test_functional_signature(self, kernel, input_type):
......@@ -1148,7 +1148,7 @@ class TestAffine:
),
)
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
@pytest.mark.parametrize("translate", _CORRECTNESS_AFFINE_KWARGS["translate"])
@pytest.mark.parametrize("scale", _CORRECTNESS_AFFINE_KWARGS["scale"])
......@@ -1176,7 +1176,7 @@ class TestAffine:
torch.testing.assert_close(actual, expected)
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
@pytest.mark.parametrize("seed", list(range(5)))
def test_transform_bounding_boxes_correctness(self, format, center, seed):
......@@ -1283,7 +1283,7 @@ class TestVerticalFlip:
def test_kernel_image_tensor(self, dtype, device):
check_kernel(F.vertical_flip_image, make_image(dtype=dtype, device=device))
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_boxes(self, format, dtype, device):
......@@ -1314,10 +1314,10 @@ class TestVerticalFlip:
[
(F.vertical_flip_image, torch.Tensor),
(F._vertical_flip_image_pil, PIL.Image.Image),
(F.vertical_flip_image, datapoints.Image),
(F.vertical_flip_bounding_boxes, datapoints.BoundingBoxes),
(F.vertical_flip_mask, datapoints.Mask),
(F.vertical_flip_video, datapoints.Video),
(F.vertical_flip_image, tv_tensors.Image),
(F.vertical_flip_bounding_boxes, tv_tensors.BoundingBoxes),
(F.vertical_flip_mask, tv_tensors.Mask),
(F.vertical_flip_video, tv_tensors.Video),
],
)
def test_functional_signature(self, kernel, input_type):
......@@ -1350,7 +1350,7 @@ class TestVerticalFlip:
return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
def test_bounding_boxes_correctness(self, format, fn):
bounding_boxes = make_bounding_boxes(format=format)
......@@ -1419,7 +1419,7 @@ class TestRotate:
expand=[False, True],
center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"],
)
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_boxes(self, param, value, format, dtype, device):
......@@ -1456,10 +1456,10 @@ class TestRotate:
[
(F.rotate_image, torch.Tensor),
(F._rotate_image_pil, PIL.Image.Image),
(F.rotate_image, datapoints.Image),
(F.rotate_bounding_boxes, datapoints.BoundingBoxes),
(F.rotate_mask, datapoints.Mask),
(F.rotate_video, datapoints.Video),
(F.rotate_image, tv_tensors.Image),
(F.rotate_bounding_boxes, tv_tensors.BoundingBoxes),
(F.rotate_mask, tv_tensors.Mask),
(F.rotate_video, tv_tensors.Video),
],
)
def test_functional_signature(self, kernel, input_type):
......@@ -1553,11 +1553,11 @@ class TestRotate:
def _recenter_bounding_boxes_after_expand(self, bounding_boxes, *, recenter_xy):
x, y = recenter_xy
if bounding_boxes.format is datapoints.BoundingBoxFormat.XYXY:
if bounding_boxes.format is tv_tensors.BoundingBoxFormat.XYXY:
translate = [x, y, x, y]
else:
translate = [x, y, 0.0, 0.0]
return datapoints.wrap(
return tv_tensors.wrap(
(bounding_boxes.to(torch.float64) - torch.tensor(translate)).to(bounding_boxes.dtype), like=bounding_boxes
)
......@@ -1590,7 +1590,7 @@ class TestRotate:
bounding_boxes
)
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
@pytest.mark.parametrize("expand", [False, True])
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
......@@ -1603,7 +1603,7 @@ class TestRotate:
torch.testing.assert_close(actual, expected)
torch.testing.assert_close(F.get_size(actual), F.get_size(expected), atol=2 if expand else 0, rtol=0)
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("expand", [False, True])
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
@pytest.mark.parametrize("seed", list(range(5)))
......@@ -1861,7 +1861,7 @@ class TestToDtype:
# make sure "others" works as a catch-all and that None means no conversion
sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input)
out = transforms.ToDtype(dtype={datapoints.Mask: torch.int64, "others": None})(sample)
out = transforms.ToDtype(dtype={tv_tensors.Mask: torch.int64, "others": None})(sample)
assert out["inpt"].dtype == inpt_dtype
assert out["bbox"].dtype == bbox_dtype
assert out["mask"].dtype != mask_dtype
......@@ -1874,7 +1874,7 @@ class TestToDtype:
sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input)
out = transforms.ToDtype(
dtype={type(sample["inpt"]): torch.float32, datapoints.Mask: torch.int64, "others": None}, scale=True
dtype={type(sample["inpt"]): torch.float32, tv_tensors.Mask: torch.int64, "others": None}, scale=True
)(sample)
assert out["inpt"].dtype != inpt_dtype
assert out["inpt"].dtype == torch.float32
......@@ -1888,9 +1888,9 @@ class TestToDtype:
sample, inpt_dtype, bbox_dtype, mask_dtype = self.make_inpt_with_bbox_and_mask(make_input)
with pytest.raises(ValueError, match="No dtype was specified for"):
out = transforms.ToDtype(dtype={datapoints.Mask: torch.float32})(sample)
out = transforms.ToDtype(dtype={tv_tensors.Mask: torch.float32})(sample)
with pytest.warns(UserWarning, match=re.escape("plain `torch.Tensor` will *not* be transformed")):
transforms.ToDtype(dtype={torch.Tensor: torch.float32, datapoints.Image: torch.float32})
transforms.ToDtype(dtype={torch.Tensor: torch.float32, tv_tensors.Image: torch.float32})
with pytest.warns(UserWarning, match="no scaling will be done"):
out = transforms.ToDtype(dtype={"others": None}, scale=True)(sample)
assert out["inpt"].dtype == inpt_dtype
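A minimal sketch (not part of this patch) of how the per-type dtype mapping exercised above might be used directly; the sample dict and its keys are assumptions for illustration:

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2 as transforms

sample = {
    "image": tv_tensors.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)),
    "mask": tv_tensors.Mask(torch.randint(0, 2, (1, 32, 32), dtype=torch.uint8)),
}
to_dtype = transforms.ToDtype(
    dtype={tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.int64, "others": None}, scale=True
)
out = to_dtype(sample)
assert out["image"].dtype is torch.float32  # converted and scaled because scale=True
assert out["mask"].dtype is torch.int64     # cast to int64 only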
......@@ -1923,8 +1923,8 @@ class TestAdjustBrightness:
[
(F.adjust_brightness_image, torch.Tensor),
(F._adjust_brightness_image_pil, PIL.Image.Image),
(F.adjust_brightness_image, datapoints.Image),
(F.adjust_brightness_video, datapoints.Video),
(F.adjust_brightness_image, tv_tensors.Image),
(F.adjust_brightness_video, tv_tensors.Video),
],
)
def test_functional_signature(self, kernel, input_type):
......@@ -2028,8 +2028,8 @@ class TestCutMixMixUp:
for input_with_bad_type in (
F.to_pil_image(imgs[0]),
datapoints.Mask(torch.rand(12, 12)),
datapoints.BoundingBoxes(torch.rand(2, 4), format="XYXY", canvas_size=12),
tv_tensors.Mask(torch.rand(12, 12)),
tv_tensors.BoundingBoxes(torch.rand(2, 4), format="XYXY", canvas_size=12),
):
with pytest.raises(ValueError, match="does not support PIL images, "):
cutmix_mixup(input_with_bad_type)
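A minimal sketch (not part of this patch) of the expected inputs for the CutMix/MixUp transforms tested above: a batch of plain tensor images plus integer class labels; the batch shape and num_classes are assumptions for illustration:

import torch
from torchvision.transforms import v2 as transforms

images = torch.rand(4, 3, 32, 32)            # batched tensor images
labels = torch.randint(0, 10, (4,))           # integer class labels
cutmix = transforms.CutMix(num_classes=10)
mixed_images, mixed_labels = cutmix(images, labels)
assert mixed_labels.shape == (4, 10)          # labels are mixed into soft targets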
......@@ -2172,12 +2172,12 @@ class TestShapeGetters:
class TestRegisterKernel:
@pytest.mark.parametrize("functional", (F.resize, "resize"))
def test_register_kernel(self, functional):
class CustomDatapoint(datapoints.Datapoint):
class CustomTVTensor(tv_tensors.TVTensor):
pass
kernel_was_called = False
@F.register_kernel(functional, CustomDatapoint)
@F.register_kernel(functional, CustomTVTensor)
def new_resize(dp, *args, **kwargs):
nonlocal kernel_was_called
kernel_was_called = True
......@@ -2185,38 +2185,38 @@ class TestRegisterKernel:
t = transforms.Resize(size=(224, 224), antialias=True)
my_dp = CustomDatapoint(torch.rand(3, 10, 10))
my_dp = CustomTVTensor(torch.rand(3, 10, 10))
out = t(my_dp)
assert out is my_dp
assert kernel_was_called
# Sanity check to make sure we didn't override the kernel of other types
assert t(torch.rand(3, 10, 10)).shape == (3, 224, 224)
assert t(datapoints.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224)
assert t(tv_tensors.Image(torch.rand(3, 10, 10))).shape == (3, 224, 224)
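A minimal sketch (not part of this patch) mirroring the test above: registering a kernel for a custom TVTensor subclass so the v2 transforms dispatch to it; the subclass and kernel names are made up:

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2 as transforms
from torchvision.transforms.v2 import functional as F

class MyTensor(tv_tensors.TVTensor):
    pass

@F.register_kernel(F.resize, MyTensor)
def resize_my_tensor(inpt, size, **kwargs):
    # no-op "kernel", just to show the dispatch mechanism
    return inpt

out = transforms.Resize(size=(224, 224), antialias=True)(MyTensor(torch.rand(3, 10, 10)))
assert isinstance(out, MyTensor)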
def test_errors(self):
with pytest.raises(ValueError, match="Could not find functional with name"):
F.register_kernel("bad_name", datapoints.Image)
F.register_kernel("bad_name", tv_tensors.Image)
with pytest.raises(ValueError, match="Kernels can only be registered on functionals"):
F.register_kernel(datapoints.Image, F.resize)
F.register_kernel(tv_tensors.Image, F.resize)
with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"):
F.register_kernel(F.resize, object)
with pytest.raises(ValueError, match="cannot be registered for the builtin datapoint classes"):
F.register_kernel(F.resize, datapoints.Image)(F.resize_image)
with pytest.raises(ValueError, match="cannot be registered for the builtin tv_tensor classes"):
F.register_kernel(F.resize, tv_tensors.Image)(F.resize_image)
class CustomDatapoint(datapoints.Datapoint):
class CustomTVTensor(tv_tensors.TVTensor):
pass
def resize_custom_datapoint():
def resize_custom_tv_tensor():
pass
F.register_kernel(F.resize, CustomDatapoint)(resize_custom_datapoint)
F.register_kernel(F.resize, CustomTVTensor)(resize_custom_tv_tensor)
with pytest.raises(ValueError, match="already has a kernel registered for type"):
F.register_kernel(F.resize, CustomDatapoint)(resize_custom_datapoint)
F.register_kernel(F.resize, CustomTVTensor)(resize_custom_tv_tensor)
class TestGetKernel:
......@@ -2225,10 +2225,10 @@ class TestGetKernel:
KERNELS = {
torch.Tensor: F.resize_image,
PIL.Image.Image: F._resize_image_pil,
datapoints.Image: F.resize_image,
datapoints.BoundingBoxes: F.resize_bounding_boxes,
datapoints.Mask: F.resize_mask,
datapoints.Video: F.resize_video,
tv_tensors.Image: F.resize_image,
tv_tensors.BoundingBoxes: F.resize_bounding_boxes,
tv_tensors.Mask: F.resize_mask,
tv_tensors.Video: F.resize_video,
}
@pytest.mark.parametrize("input_type", [str, int, object])
......@@ -2244,57 +2244,57 @@ class TestGetKernel:
pass
for input_type, kernel in self.KERNELS.items():
_register_kernel_internal(resize_with_pure_kernels, input_type, datapoint_wrapper=False)(kernel)
_register_kernel_internal(resize_with_pure_kernels, input_type, tv_tensor_wrapper=False)(kernel)
assert _get_kernel(resize_with_pure_kernels, input_type) is kernel
def test_builtin_datapoint_subclass(self):
def test_builtin_tv_tensor_subclass(self):
# We cannot use F.resize together with self.KERNELS mapping directly here, since this is only the
# ideal wrapping. Practically, we have an intermediate wrapper layer. Thus, we create a new resize functional
# here, register the kernels without wrapper, and check if subclasses of our builtin datapoints get dispatched
# here, register the kernels without wrapper, and check if subclasses of our builtin tv_tensors get dispatched
# to the kernel of the corresponding superclass
def resize_with_pure_kernels():
pass
class MyImage(datapoints.Image):
class MyImage(tv_tensors.Image):
pass
class MyBoundingBoxes(datapoints.BoundingBoxes):
class MyBoundingBoxes(tv_tensors.BoundingBoxes):
pass
class MyMask(datapoints.Mask):
class MyMask(tv_tensors.Mask):
pass
class MyVideo(datapoints.Video):
class MyVideo(tv_tensors.Video):
pass
for custom_datapoint_subclass in [
for custom_tv_tensor_subclass in [
MyImage,
MyBoundingBoxes,
MyMask,
MyVideo,
]:
builtin_datapoint_class = custom_datapoint_subclass.__mro__[1]
builtin_datapoint_kernel = self.KERNELS[builtin_datapoint_class]
_register_kernel_internal(resize_with_pure_kernels, builtin_datapoint_class, datapoint_wrapper=False)(
builtin_datapoint_kernel
builtin_tv_tensor_class = custom_tv_tensor_subclass.__mro__[1]
builtin_tv_tensor_kernel = self.KERNELS[builtin_tv_tensor_class]
_register_kernel_internal(resize_with_pure_kernels, builtin_tv_tensor_class, tv_tensor_wrapper=False)(
builtin_tv_tensor_kernel
)
assert _get_kernel(resize_with_pure_kernels, custom_datapoint_subclass) is builtin_datapoint_kernel
assert _get_kernel(resize_with_pure_kernels, custom_tv_tensor_subclass) is builtin_tv_tensor_kernel
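A minimal sketch (not part of this patch) of the behaviour this test verifies: a user-defined subclass of a builtin tv_tensor is dispatched to the kernel of its parent class by the v2 transforms; MyImage is made up for illustration:

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2 as transforms

class MyImage(tv_tensors.Image):
    pass

out = transforms.Resize(size=(8, 8), antialias=True)(MyImage(torch.rand(3, 16, 16)))
assert isinstance(out, MyImage) and out.shape[-2:] == (8, 8)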
def test_datapoint_subclass(self):
class MyDatapoint(datapoints.Datapoint):
def test_tv_tensor_subclass(self):
class MyTVTensor(tv_tensors.TVTensor):
pass
with pytest.raises(TypeError, match="supports inputs of type"):
_get_kernel(F.resize, MyDatapoint)
_get_kernel(F.resize, MyTVTensor)
def resize_my_datapoint():
def resize_my_tv_tensor():
pass
_register_kernel_internal(F.resize, MyDatapoint, datapoint_wrapper=False)(resize_my_datapoint)
_register_kernel_internal(F.resize, MyTVTensor, tv_tensor_wrapper=False)(resize_my_tv_tensor)
assert _get_kernel(F.resize, MyDatapoint) is resize_my_datapoint
assert _get_kernel(F.resize, MyTVTensor) is resize_my_tv_tensor
def test_pil_image_subclass(self):
opened_image = PIL.Image.open(Path(__file__).parent / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg")
......@@ -2342,8 +2342,8 @@ class TestPermuteChannels:
[
(F.permute_channels_image, torch.Tensor),
(F._permute_channels_image_pil, PIL.Image.Image),
(F.permute_channels_image, datapoints.Image),
(F.permute_channels_video, datapoints.Video),
(F.permute_channels_image, tv_tensors.Image),
(F.permute_channels_video, tv_tensors.Video),
],
)
def test_functional_signature(self, kernel, input_type):
......@@ -2352,7 +2352,7 @@ class TestPermuteChannels:
def reference_image_correctness(self, image, permutation):
channel_images = image.split(1, dim=-3)
permuted_channel_images = [channel_images[channel_idx] for channel_idx in permutation]
return datapoints.Image(torch.concat(permuted_channel_images, dim=-3))
return tv_tensors.Image(torch.concat(permuted_channel_images, dim=-3))
@pytest.mark.parametrize("permutation", [[2, 0, 1], [1, 2, 0], [2, 0, 1], [0, 1, 2]])
@pytest.mark.parametrize("batch_dims", [(), (2,), (2, 1)])
......@@ -2392,7 +2392,7 @@ class TestElastic:
check_scripted_vs_eager=not (param == "fill" and isinstance(value, (int, float))),
)
@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_boxes(self, format, dtype, device):
......@@ -2428,10 +2428,10 @@ class TestElastic:
[
(F.elastic_image, torch.Tensor),
(F._elastic_image_pil, PIL.Image.Image),
(F.elastic_image, datapoints.Image),
(F.elastic_bounding_boxes, datapoints.BoundingBoxes),
(F.elastic_mask, datapoints.Mask),
(F.elastic_video, datapoints.Video),
(F.elastic_image, tv_tensors.Image),
(F.elastic_bounding_boxes, tv_tensors.BoundingBoxes),
(F.elastic_mask, tv_tensors.Mask),
(F.elastic_video, tv_tensors.Video),
],
)
def test_functional_signature(self, kernel, input_type):
......@@ -2481,7 +2481,7 @@ class TestToPureTensor:
out = transforms.ToPureTensor()(input)
for input_value, out_value in zip(input.values(), out.values()):
if isinstance(input_value, datapoints.Datapoint):
assert isinstance(out_value, torch.Tensor) and not isinstance(out_value, datapoints.Datapoint)
if isinstance(input_value, tv_tensors.TVTensor):
assert isinstance(out_value, torch.Tensor) and not isinstance(out_value, tv_tensors.TVTensor)
else:
assert isinstance(out_value, type(input_value))
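A minimal sketch (not part of this patch) of the ToPureTensor behaviour tested above: the TVTensor subclass is stripped so downstream code sees plain tensors; the sample dict is assumed for illustration:

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2 as transforms

sample = {"image": tv_tensors.Image(torch.rand(3, 8, 8)), "label": 3}
out = transforms.ToPureTensor()(sample)
assert type(out["image"]) is torch.Tensor  # no longer a tv_tensors.Image
assert out["label"] == 3                   # non-tensor values pass through unchanged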
......@@ -6,46 +6,46 @@ import torch
import torchvision.transforms.v2._utils
from common_utils import DEFAULT_SIZE, make_bounding_boxes, make_detection_mask, make_image
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms.v2._utils import has_all, has_any
from torchvision.transforms.v2.functional import to_pil_image
IMAGE = make_image(DEFAULT_SIZE, color_space="RGB")
BOUNDING_BOX = make_bounding_boxes(DEFAULT_SIZE, format=datapoints.BoundingBoxFormat.XYXY)
BOUNDING_BOX = make_bounding_boxes(DEFAULT_SIZE, format=tv_tensors.BoundingBoxFormat.XYXY)
MASK = make_detection_mask(DEFAULT_SIZE)
@pytest.mark.parametrize(
("sample", "types", "expected"),
[
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes,), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes, datapoints.Mask), True),
((MASK,), (datapoints.Image, datapoints.BoundingBoxes), False),
((BOUNDING_BOX,), (datapoints.Image, datapoints.Mask), False),
((IMAGE,), (datapoints.BoundingBoxes, datapoints.Mask), False),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
((MASK,), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
((BOUNDING_BOX,), (tv_tensors.Image, tv_tensors.Mask), False),
((IMAGE,), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
(
(IMAGE, BOUNDING_BOX, MASK),
(datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask),
(tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
True,
),
((), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True),
((), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, tv_tensors.Image),), True),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), True),
((IMAGE,), (tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor), True),
(
(torch.Tensor(IMAGE),),
(datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
(tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
True,
),
(
(to_pil_image(IMAGE),),
(datapoints.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
(tv_tensors.Image, PIL.Image.Image, torchvision.transforms.v2._utils.is_pure_tensor),
True,
),
],
......@@ -57,31 +57,31 @@ def test_has_any(sample, types, expected):
@pytest.mark.parametrize(
("sample", "types", "expected"),
[
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image,), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes,), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), True),
((IMAGE, BOUNDING_BOX, MASK), (datapoints.BoundingBoxes, datapoints.Mask), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), True),
((IMAGE, BOUNDING_BOX, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), True),
(
(IMAGE, BOUNDING_BOX, MASK),
(datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask),
(tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
True,
),
((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes), False),
((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.Mask), False),
((IMAGE, MASK), (datapoints.BoundingBoxes, datapoints.Mask), False),
((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes), False),
((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.Mask), False),
((IMAGE, MASK), (tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
(
(IMAGE, BOUNDING_BOX, MASK),
(datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask),
(tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask),
True,
),
((BOUNDING_BOX, MASK), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
((IMAGE, MASK), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
((IMAGE, BOUNDING_BOX), (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask), False),
((BOUNDING_BOX, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
((IMAGE, MASK), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
((IMAGE, BOUNDING_BOX), (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask), False),
(
(IMAGE, BOUNDING_BOX, MASK),
(lambda obj: isinstance(obj, (datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)),),
(lambda obj: isinstance(obj, (tv_tensors.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask)),),
True,
),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
......
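A minimal sketch (not part of this patch) of the helpers exercised above: has_any / has_all check whether a flat list of sample components contains any / all of the given tv_tensor types (or passes the given check callables):

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2._utils import has_all, has_any

flat_sample = [tv_tensors.Image(torch.rand(3, 8, 8)), tv_tensors.Mask(torch.zeros(1, 8, 8))]
assert has_any(flat_sample, tv_tensors.Image, tv_tensors.BoundingBoxes)
assert not has_all(flat_sample, tv_tensors.Image, tv_tensors.BoundingBoxes)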
......@@ -5,7 +5,7 @@ import torch
from common_utils import assert_equal, make_bounding_boxes, make_image, make_segmentation_mask, make_video
from PIL import Image
from torchvision import datapoints
from torchvision import tv_tensors
@pytest.fixture(autouse=True)
......@@ -13,40 +13,40 @@ def restore_tensor_return_type():
# This is a safety net, as we should already be restoring the default manually in each test anyway
# (at least at the time of writing...)
yield
datapoints.set_return_type("Tensor")
tv_tensors.set_return_type("Tensor")
@pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)])
def test_image_instance(data):
image = datapoints.Image(data)
image = tv_tensors.Image(data)
assert isinstance(image, torch.Tensor)
assert image.ndim == 3 and image.shape[0] == 3
@pytest.mark.parametrize("data", [torch.randint(0, 10, size=(1, 32, 32)), Image.new("L", (32, 32), color=2)])
def test_mask_instance(data):
mask = datapoints.Mask(data)
mask = tv_tensors.Mask(data)
assert isinstance(mask, torch.Tensor)
assert mask.ndim == 3 and mask.shape[0] == 1
@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]], [1, 2, 3, 4]])
@pytest.mark.parametrize(
"format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH]
"format", ["XYXY", "CXCYWH", tv_tensors.BoundingBoxFormat.XYXY, tv_tensors.BoundingBoxFormat.XYWH]
)
def test_bbox_instance(data, format):
bboxes = datapoints.BoundingBoxes(data, format=format, canvas_size=(32, 32))
bboxes = tv_tensors.BoundingBoxes(data, format=format, canvas_size=(32, 32))
assert isinstance(bboxes, torch.Tensor)
assert bboxes.ndim == 2 and bboxes.shape[1] == 4
if isinstance(format, str):
format = datapoints.BoundingBoxFormat[(format.upper())]
format = tv_tensors.BoundingBoxFormat[(format.upper())]
assert bboxes.format == format
def test_bbox_dim_error():
data_3d = [[[1, 2, 3, 4]]]
with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"):
datapoints.BoundingBoxes(data_3d, format="XYXY", canvas_size=(32, 32))
tv_tensors.BoundingBoxes(data_3d, format="XYXY", canvas_size=(32, 32))
@pytest.mark.parametrize(
......@@ -64,8 +64,8 @@ def test_bbox_dim_error():
],
)
def test_new_requires_grad(data, input_requires_grad, expected_requires_grad):
datapoint = datapoints.Image(data, requires_grad=input_requires_grad)
assert datapoint.requires_grad is expected_requires_grad
tv_tensor = tv_tensors.Image(data, requires_grad=input_requires_grad)
assert tv_tensor.requires_grad is expected_requires_grad
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
......@@ -75,7 +75,7 @@ def test_isinstance(make_input):
def test_wrapping_no_copy():
tensor = torch.rand(3, 16, 16)
image = datapoints.Image(tensor)
image = tv_tensors.Image(tensor)
assert image.data_ptr() == tensor.data_ptr()
......@@ -91,25 +91,25 @@ def test_to_wrapping(make_input):
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
def test_to_datapoint_reference(make_input, return_type):
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_to_tv_tensor_reference(make_input, return_type):
tensor = torch.rand((3, 16, 16), dtype=torch.float64)
dp = make_input()
with datapoints.set_return_type(return_type):
with tv_tensors.set_return_type(return_type):
tensor_to = tensor.to(dp)
assert type(tensor_to) is (type(dp) if return_type == "datapoint" else torch.Tensor)
assert type(tensor_to) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
assert tensor_to.dtype is dp.dtype
assert type(tensor) is torch.Tensor
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_clone_wrapping(make_input, return_type):
dp = make_input()
with datapoints.set_return_type(return_type):
with tv_tensors.set_return_type(return_type):
dp_clone = dp.clone()
assert type(dp_clone) is type(dp)
......@@ -117,13 +117,13 @@ def test_clone_wrapping(make_input, return_type):
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_requires_grad__wrapping(make_input, return_type):
dp = make_input(dtype=torch.float)
assert not dp.requires_grad
with datapoints.set_return_type(return_type):
with tv_tensors.set_return_type(return_type):
dp_requires_grad = dp.requires_grad_(True)
assert type(dp_requires_grad) is type(dp)
......@@ -132,54 +132,54 @@ def test_requires_grad__wrapping(make_input, return_type):
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_detach_wrapping(make_input, return_type):
dp = make_input(dtype=torch.float).requires_grad_(True)
with datapoints.set_return_type(return_type):
with tv_tensors.set_return_type(return_type):
dp_detached = dp.detach()
assert type(dp_detached) is type(dp)
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_force_subclass_with_metadata(return_type):
# Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS and datapoints with metadata
# Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS and tv_tensors with metadata
# Largely the same as above; we additionally check that the metadata is preserved
format, canvas_size = "XYXY", (32, 32)
bbox = datapoints.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size)
bbox = tv_tensors.BoundingBoxes([[0, 0, 5, 5], [2, 2, 7, 7]], format=format, canvas_size=canvas_size)
datapoints.set_return_type(return_type)
tv_tensors.set_return_type(return_type)
bbox = bbox.clone()
if return_type == "datapoint":
if return_type == "tv_tensor":
assert (bbox.format, bbox.canvas_size) == (format, canvas_size)
bbox = bbox.to(torch.float64)
if return_type == "datapoint":
if return_type == "tv_tensor":
assert (bbox.format, bbox.canvas_size) == (format, canvas_size)
bbox = bbox.detach()
if return_type == "datapoint":
if return_type == "tv_tensor":
assert (bbox.format, bbox.canvas_size) == (format, canvas_size)
assert not bbox.requires_grad
bbox.requires_grad_(True)
if return_type == "datapoint":
if return_type == "tv_tensor":
assert (bbox.format, bbox.canvas_size) == (format, canvas_size)
assert bbox.requires_grad
datapoints.set_return_type("tensor")
tv_tensors.set_return_type("tensor")
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_other_op_no_wrapping(make_input, return_type):
dp = make_input()
with datapoints.set_return_type(return_type):
with tv_tensors.set_return_type(return_type):
# any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
output = dp * 2
assert type(output) is (type(dp) if return_type == "datapoint" else torch.Tensor)
assert type(output) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
......@@ -200,15 +200,15 @@ def test_no_tensor_output_op_no_wrapping(make_input, op):
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
def test_inplace_op_no_wrapping(make_input, return_type):
dp = make_input()
original_type = type(dp)
with datapoints.set_return_type(return_type):
with tv_tensors.set_return_type(return_type):
output = dp.add_(0)
assert type(output) is (type(dp) if return_type == "datapoint" else torch.Tensor)
assert type(output) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
assert type(dp) is original_type
......@@ -219,7 +219,7 @@ def test_wrap(make_input):
# any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
output = dp * 2
dp_new = datapoints.wrap(output, like=dp)
dp_new = tv_tensors.wrap(output, like=dp)
assert type(dp_new) is type(dp)
assert dp_new.data_ptr() == output.data_ptr()
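A minimal sketch (not part of this patch) of tv_tensors.wrap as used above: it re-wraps a plain tensor result as the same tv_tensor type, carrying over the metadata of a reference input:

import torch
from torchvision import tv_tensors

boxes = tv_tensors.BoundingBoxes([[0, 0, 5, 5]], format="XYXY", canvas_size=(32, 32))
shifted = boxes.as_subclass(torch.Tensor) + 2     # plain tensor result of some op
shifted = tv_tensors.wrap(shifted, like=boxes)    # back to BoundingBoxes with the same metadata
assert isinstance(shifted, tv_tensors.BoundingBoxes)
assert shifted.format is tv_tensors.BoundingBoxFormat.XYXY and shifted.canvas_size == (32, 32)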
......@@ -243,7 +243,7 @@ def test_deepcopy(make_input, requires_grad):
@pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@pytest.mark.parametrize("return_type", ["Tensor", "datapoint"])
@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
@pytest.mark.parametrize(
"op",
(
......@@ -265,10 +265,10 @@ def test_deepcopy(make_input, requires_grad):
def test_usual_operations(make_input, return_type, op):
dp = make_input()
with datapoints.set_return_type(return_type):
with tv_tensors.set_return_type(return_type):
out = op(dp)
assert type(out) is (type(dp) if return_type == "datapoint" else torch.Tensor)
if isinstance(dp, datapoints.BoundingBoxes) and return_type == "datapoint":
assert type(out) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
if isinstance(dp, tv_tensors.BoundingBoxes) and return_type == "tv_tensor":
assert hasattr(out, "format")
assert hasattr(out, "canvas_size")
......@@ -286,22 +286,22 @@ def test_set_return_type():
assert type(img + 3) is torch.Tensor
with datapoints.set_return_type("datapoint"):
assert type(img + 3) is datapoints.Image
with tv_tensors.set_return_type("tv_tensor"):
assert type(img + 3) is tv_tensors.Image
assert type(img + 3) is torch.Tensor
datapoints.set_return_type("datapoint")
assert type(img + 3) is datapoints.Image
tv_tensors.set_return_type("tv_tensor")
assert type(img + 3) is tv_tensors.Image
with datapoints.set_return_type("tensor"):
with tv_tensors.set_return_type("tensor"):
assert type(img + 3) is torch.Tensor
with datapoints.set_return_type("datapoint"):
assert type(img + 3) is datapoints.Image
datapoints.set_return_type("tensor")
with tv_tensors.set_return_type("tv_tensor"):
assert type(img + 3) is tv_tensors.Image
tv_tensors.set_return_type("tensor")
assert type(img + 3) is torch.Tensor
assert type(img + 3) is torch.Tensor
# Exiting a context manager will restore the return type as it was prior to entering it,
# regardless of whether the "global" datapoints.set_return_type() was called within the context manager.
assert type(img + 3) is datapoints.Image
# regardless of whether the "global" tv_tensors.set_return_type() was called within the context manager.
assert type(img + 3) is tv_tensors.Image
datapoints.set_return_type("tensor")
tv_tensors.set_return_type("tensor")
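A minimal sketch (not part of this patch) condensing the behaviour this test verifies: set_return_type controls whether plain torch ops on a tv_tensor return the subclass or a plain Tensor, and the context manager restores the previous mode on exit:

import torch
from torchvision import tv_tensors

img = tv_tensors.Image(torch.rand(3, 8, 8))
assert type(img + 1) is torch.Tensor             # default return type is "Tensor"
with tv_tensors.set_return_type("tv_tensor"):
    assert type(img + 1) is tv_tensors.Image     # inside the context, the subclass is kept
assert type(img + 1) is torch.Tensor             # restored on exit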
......@@ -2,7 +2,7 @@ import collections.abc
import pytest
import torchvision.transforms.v2.functional as F
from torchvision import datapoints
from torchvision import tv_tensors
from transforms_v2_kernel_infos import KERNEL_INFOS, pad_xfail_jit_fill_condition
from transforms_v2_legacy_utils import InfoBase, TestMark
......@@ -44,19 +44,19 @@ class DispatcherInfo(InfoBase):
self.pil_kernel_info = pil_kernel_info
kernel_infos = {}
for datapoint_type, kernel in self.kernels.items():
for tv_tensor_type, kernel in self.kernels.items():
kernel_info = self._KERNEL_INFO_MAP.get(kernel)
if not kernel_info:
raise pytest.UsageError(
f"Can't register {kernel.__name__} for type {datapoint_type} since there is no `KernelInfo` for it. "
f"Can't register {kernel.__name__} for type {tv_tensor_type} since there is no `KernelInfo` for it. "
f"Please add a `KernelInfo` for it in `transforms_v2_kernel_infos.py`."
)
kernel_infos[datapoint_type] = kernel_info
kernel_infos[tv_tensor_type] = kernel_info
self.kernel_infos = kernel_infos
def sample_inputs(self, *datapoint_types, filter_metadata=True):
for datapoint_type in datapoint_types or self.kernel_infos.keys():
kernel_info = self.kernel_infos.get(datapoint_type)
def sample_inputs(self, *tv_tensor_types, filter_metadata=True):
for tv_tensor_type in tv_tensor_types or self.kernel_infos.keys():
kernel_info = self.kernel_infos.get(tv_tensor_type)
if not kernel_info:
raise pytest.UsageError(f"There is no kernel registered for type {type.__name__}")
......@@ -69,12 +69,12 @@ class DispatcherInfo(InfoBase):
import itertools
for args_kwargs in sample_inputs:
if hasattr(datapoint_type, "__annotations__"):
if hasattr(tv_tensor_type, "__annotations__"):
for name in itertools.chain(
datapoint_type.__annotations__.keys(),
tv_tensor_type.__annotations__.keys(),
# FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a
# per-dispatcher level. However, so far there is no option for that.
(f"old_{name}" for name in datapoint_type.__annotations__.keys()),
(f"old_{name}" for name in tv_tensor_type.__annotations__.keys()),
):
if name in args_kwargs.kwargs:
del args_kwargs.kwargs[name]
......@@ -97,9 +97,9 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
)
skip_dispatch_datapoint = TestMark(
("TestDispatchers", "test_dispatch_datapoint"),
pytest.mark.skip(reason="Dispatcher doesn't support arbitrary datapoint dispatch."),
skip_dispatch_tv_tensor = TestMark(
("TestDispatchers", "test_dispatch_tv_tensor"),
pytest.mark.skip(reason="Dispatcher doesn't support arbitrary tv_tensor dispatch."),
)
multi_crop_skips = [
......@@ -107,9 +107,9 @@ multi_crop_skips = [
("TestDispatchers", test_name),
pytest.mark.skip(reason="Multi-crop dispatchers return a sequence of items rather than a single one."),
)
for test_name in ["test_pure_tensor_output_type", "test_pil_output_type", "test_datapoint_output_type"]
for test_name in ["test_pure_tensor_output_type", "test_pil_output_type", "test_tv_tensor_output_type"]
]
multi_crop_skips.append(skip_dispatch_datapoint)
multi_crop_skips.append(skip_dispatch_tv_tensor)
def xfails_pil(reason, *, condition=None):
......@@ -142,30 +142,30 @@ DISPATCHER_INFOS = [
DispatcherInfo(
F.crop,
kernels={
datapoints.Image: F.crop_image,
datapoints.Video: F.crop_video,
datapoints.BoundingBoxes: F.crop_bounding_boxes,
datapoints.Mask: F.crop_mask,
tv_tensors.Image: F.crop_image,
tv_tensors.Video: F.crop_video,
tv_tensors.BoundingBoxes: F.crop_bounding_boxes,
tv_tensors.Mask: F.crop_mask,
},
pil_kernel_info=PILKernelInfo(F._crop_image_pil, kernel_name="crop_image_pil"),
),
DispatcherInfo(
F.resized_crop,
kernels={
datapoints.Image: F.resized_crop_image,
datapoints.Video: F.resized_crop_video,
datapoints.BoundingBoxes: F.resized_crop_bounding_boxes,
datapoints.Mask: F.resized_crop_mask,
tv_tensors.Image: F.resized_crop_image,
tv_tensors.Video: F.resized_crop_video,
tv_tensors.BoundingBoxes: F.resized_crop_bounding_boxes,
tv_tensors.Mask: F.resized_crop_mask,
},
pil_kernel_info=PILKernelInfo(F._resized_crop_image_pil),
),
DispatcherInfo(
F.pad,
kernels={
datapoints.Image: F.pad_image,
datapoints.Video: F.pad_video,
datapoints.BoundingBoxes: F.pad_bounding_boxes,
datapoints.Mask: F.pad_mask,
tv_tensors.Image: F.pad_image,
tv_tensors.Video: F.pad_video,
tv_tensors.BoundingBoxes: F.pad_bounding_boxes,
tv_tensors.Mask: F.pad_mask,
},
pil_kernel_info=PILKernelInfo(F._pad_image_pil, kernel_name="pad_image_pil"),
test_marks=[
......@@ -184,10 +184,10 @@ DISPATCHER_INFOS = [
DispatcherInfo(
F.perspective,
kernels={
datapoints.Image: F.perspective_image,
datapoints.Video: F.perspective_video,
datapoints.BoundingBoxes: F.perspective_bounding_boxes,
datapoints.Mask: F.perspective_mask,
tv_tensors.Image: F.perspective_image,
tv_tensors.Video: F.perspective_video,
tv_tensors.BoundingBoxes: F.perspective_bounding_boxes,
tv_tensors.Mask: F.perspective_mask,
},
pil_kernel_info=PILKernelInfo(F._perspective_image_pil),
test_marks=[
......@@ -198,10 +198,10 @@ DISPATCHER_INFOS = [
DispatcherInfo(
F.elastic,
kernels={
datapoints.Image: F.elastic_image,
datapoints.Video: F.elastic_video,
datapoints.BoundingBoxes: F.elastic_bounding_boxes,
datapoints.Mask: F.elastic_mask,
tv_tensors.Image: F.elastic_image,
tv_tensors.Video: F.elastic_video,
tv_tensors.BoundingBoxes: F.elastic_bounding_boxes,
tv_tensors.Mask: F.elastic_mask,
},
pil_kernel_info=PILKernelInfo(F._elastic_image_pil),
test_marks=[xfail_jit_python_scalar_arg("fill")],
......@@ -209,10 +209,10 @@ DISPATCHER_INFOS = [
DispatcherInfo(
F.center_crop,
kernels={
datapoints.Image: F.center_crop_image,
datapoints.Video: F.center_crop_video,
datapoints.BoundingBoxes: F.center_crop_bounding_boxes,
datapoints.Mask: F.center_crop_mask,
tv_tensors.Image: F.center_crop_image,
tv_tensors.Video: F.center_crop_video,
tv_tensors.BoundingBoxes: F.center_crop_bounding_boxes,
tv_tensors.Mask: F.center_crop_mask,
},
pil_kernel_info=PILKernelInfo(F._center_crop_image_pil),
test_marks=[
......@@ -222,8 +222,8 @@ DISPATCHER_INFOS = [
DispatcherInfo(
F.gaussian_blur,
kernels={
datapoints.Image: F.gaussian_blur_image,
datapoints.Video: F.gaussian_blur_video,
tv_tensors.Image: F.gaussian_blur_image,
tv_tensors.Video: F.gaussian_blur_video,
},
pil_kernel_info=PILKernelInfo(F._gaussian_blur_image_pil),
test_marks=[
......@@ -234,99 +234,99 @@ DISPATCHER_INFOS = [
DispatcherInfo(
F.equalize,
kernels={
datapoints.Image: F.equalize_image,
datapoints.Video: F.equalize_video,
tv_tensors.Image: F.equalize_image,
tv_tensors.Video: F.equalize_video,
},
pil_kernel_info=PILKernelInfo(F._equalize_image_pil, kernel_name="equalize_image_pil"),
),
DispatcherInfo(
F.invert,
kernels={
datapoints.Image: F.invert_image,
datapoints.Video: F.invert_video,
tv_tensors.Image: F.invert_image,
tv_tensors.Video: F.invert_video,
},
pil_kernel_info=PILKernelInfo(F._invert_image_pil, kernel_name="invert_image_pil"),
),
DispatcherInfo(
F.posterize,
kernels={
datapoints.Image: F.posterize_image,
datapoints.Video: F.posterize_video,
tv_tensors.Image: F.posterize_image,
tv_tensors.Video: F.posterize_video,
},
pil_kernel_info=PILKernelInfo(F._posterize_image_pil, kernel_name="posterize_image_pil"),
),
DispatcherInfo(
F.solarize,
kernels={
datapoints.Image: F.solarize_image,
datapoints.Video: F.solarize_video,
tv_tensors.Image: F.solarize_image,
tv_tensors.Video: F.solarize_video,
},
pil_kernel_info=PILKernelInfo(F._solarize_image_pil, kernel_name="solarize_image_pil"),
),
DispatcherInfo(
F.autocontrast,
kernels={
datapoints.Image: F.autocontrast_image,
datapoints.Video: F.autocontrast_video,
tv_tensors.Image: F.autocontrast_image,
tv_tensors.Video: F.autocontrast_video,
},
pil_kernel_info=PILKernelInfo(F._autocontrast_image_pil, kernel_name="autocontrast_image_pil"),
),
DispatcherInfo(
F.adjust_sharpness,
kernels={
datapoints.Image: F.adjust_sharpness_image,
datapoints.Video: F.adjust_sharpness_video,
tv_tensors.Image: F.adjust_sharpness_image,
tv_tensors.Video: F.adjust_sharpness_video,
},
pil_kernel_info=PILKernelInfo(F._adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
),
DispatcherInfo(
F.erase,
kernels={
datapoints.Image: F.erase_image,
datapoints.Video: F.erase_video,
tv_tensors.Image: F.erase_image,
tv_tensors.Video: F.erase_video,
},
pil_kernel_info=PILKernelInfo(F._erase_image_pil),
test_marks=[
skip_dispatch_datapoint,
skip_dispatch_tv_tensor,
],
),
DispatcherInfo(
F.adjust_contrast,
kernels={
datapoints.Image: F.adjust_contrast_image,
datapoints.Video: F.adjust_contrast_video,
tv_tensors.Image: F.adjust_contrast_image,
tv_tensors.Video: F.adjust_contrast_video,
},
pil_kernel_info=PILKernelInfo(F._adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"),
),
DispatcherInfo(
F.adjust_gamma,
kernels={
datapoints.Image: F.adjust_gamma_image,
datapoints.Video: F.adjust_gamma_video,
tv_tensors.Image: F.adjust_gamma_image,
tv_tensors.Video: F.adjust_gamma_video,
},
pil_kernel_info=PILKernelInfo(F._adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"),
),
DispatcherInfo(
F.adjust_hue,
kernels={
datapoints.Image: F.adjust_hue_image,
datapoints.Video: F.adjust_hue_video,
tv_tensors.Image: F.adjust_hue_image,
tv_tensors.Video: F.adjust_hue_video,
},
pil_kernel_info=PILKernelInfo(F._adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"),
),
DispatcherInfo(
F.adjust_saturation,
kernels={
datapoints.Image: F.adjust_saturation_image,
datapoints.Video: F.adjust_saturation_video,
tv_tensors.Image: F.adjust_saturation_image,
tv_tensors.Video: F.adjust_saturation_video,
},
pil_kernel_info=PILKernelInfo(F._adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"),
),
DispatcherInfo(
F.five_crop,
kernels={
datapoints.Image: F.five_crop_image,
datapoints.Video: F.five_crop_video,
tv_tensors.Image: F.five_crop_image,
tv_tensors.Video: F.five_crop_video,
},
pil_kernel_info=PILKernelInfo(F._five_crop_image_pil),
test_marks=[
......@@ -337,8 +337,8 @@ DISPATCHER_INFOS = [
DispatcherInfo(
F.ten_crop,
kernels={
datapoints.Image: F.ten_crop_image,
datapoints.Video: F.ten_crop_video,
tv_tensors.Image: F.ten_crop_image,
tv_tensors.Video: F.ten_crop_video,
},
test_marks=[
xfail_jit_python_scalar_arg("size"),
......@@ -349,8 +349,8 @@ DISPATCHER_INFOS = [
DispatcherInfo(
F.normalize,
kernels={
datapoints.Image: F.normalize_image,
datapoints.Video: F.normalize_video,
tv_tensors.Image: F.normalize_image,
tv_tensors.Video: F.normalize_video,
},
test_marks=[
xfail_jit_python_scalar_arg("mean"),
......@@ -360,24 +360,24 @@ DISPATCHER_INFOS = [
DispatcherInfo(
F.uniform_temporal_subsample,
kernels={
datapoints.Video: F.uniform_temporal_subsample_video,
tv_tensors.Video: F.uniform_temporal_subsample_video,
},
test_marks=[
skip_dispatch_datapoint,
skip_dispatch_tv_tensor,
],
),
DispatcherInfo(
F.clamp_bounding_boxes,
kernels={datapoints.BoundingBoxes: F.clamp_bounding_boxes},
kernels={tv_tensors.BoundingBoxes: F.clamp_bounding_boxes},
test_marks=[
skip_dispatch_datapoint,
skip_dispatch_tv_tensor,
],
),
DispatcherInfo(
F.convert_bounding_box_format,
kernels={datapoints.BoundingBoxes: F.convert_bounding_box_format},
kernels={tv_tensors.BoundingBoxes: F.convert_bounding_box_format},
test_marks=[
skip_dispatch_datapoint,
skip_dispatch_tv_tensor,
],
),
]
......@@ -7,7 +7,7 @@ import pytest
import torch.testing
import torchvision.ops
import torchvision.transforms.v2.functional as F
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value, _parse_pad_padding
from transforms_v2_legacy_utils import (
ArgsKwargs,
......@@ -193,7 +193,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
bbox_xyxy = F.convert_bounding_box_format(
bbox.as_subclass(torch.Tensor),
old_format=format_,
new_format=datapoints.BoundingBoxFormat.XYXY,
new_format=tv_tensors.BoundingBoxFormat.XYXY,
inplace=True,
)
points = np.array(
......@@ -215,7 +215,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
dtype=bbox_xyxy.dtype,
)
out_bbox = F.convert_bounding_box_format(
out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
out_bbox, old_format=tv_tensors.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
)
# It is important to clamp before casting, especially for CXCYWH format, dtype=int64
out_bbox = F.clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_)
......@@ -228,7 +228,7 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
def sample_inputs_convert_bounding_box_format():
formats = list(datapoints.BoundingBoxFormat)
formats = list(tv_tensors.BoundingBoxFormat)
for bounding_boxes_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
yield ArgsKwargs(bounding_boxes_loader, old_format=bounding_boxes_loader.format, new_format=new_format)
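A minimal sketch (not part of this patch) of the format conversion sampled above, applied to a tv_tensor directly (old_format is only needed for plain tensor inputs):

from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

boxes = tv_tensors.BoundingBoxes([[2, 2, 7, 7]], format="XYXY", canvas_size=(32, 32))
boxes_xywh = F.convert_bounding_box_format(boxes, new_format=tv_tensors.BoundingBoxFormat.XYWH)
assert boxes_xywh.tolist() == [[2, 2, 5, 5]]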
......@@ -659,7 +659,7 @@ def sample_inputs_perspective_bounding_boxes():
coefficients=_PERSPECTIVE_COEFFS[0],
)
format = datapoints.BoundingBoxFormat.XYXY
format = tv_tensors.BoundingBoxFormat.XYXY
loader = make_bounding_box_loader(format=format)
yield ArgsKwargs(
loader, format=format, canvas_size=loader.canvas_size, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS
......
......@@ -27,7 +27,7 @@ import PIL.Image
import pytest
import torch
from torchvision import datapoints
from torchvision import tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import to_dtype_image, to_image, to_pil_image
......@@ -82,7 +82,7 @@ def make_image(
if color_space in {"GRAY_ALPHA", "RGBA"}:
data[..., -1, :, :] = max_value
return datapoints.Image(data)
return tv_tensors.Image(data)
def make_image_tensor(*args, **kwargs):
......@@ -96,7 +96,7 @@ def make_image_pil(*args, **kwargs):
def make_bounding_boxes(
canvas_size=DEFAULT_SIZE,
*,
format=datapoints.BoundingBoxFormat.XYXY,
format=tv_tensors.BoundingBoxFormat.XYXY,
batch_dims=(),
dtype=None,
device="cpu",
......@@ -107,12 +107,12 @@ def make_bounding_boxes(
return torch.stack([torch.randint(max_value - v, ()) for v in values.flatten().tolist()]).reshape(values.shape)
if isinstance(format, str):
format = datapoints.BoundingBoxFormat[format]
format = tv_tensors.BoundingBoxFormat[format]
dtype = dtype or torch.float32
if any(dim == 0 for dim in batch_dims):
return datapoints.BoundingBoxes(
return tv_tensors.BoundingBoxes(
torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, canvas_size=canvas_size
)
......@@ -120,28 +120,28 @@ def make_bounding_boxes(
y = sample_position(h, canvas_size[0])
x = sample_position(w, canvas_size[1])
if format is datapoints.BoundingBoxFormat.XYWH:
if format is tv_tensors.BoundingBoxFormat.XYWH:
parts = (x, y, w, h)
elif format is datapoints.BoundingBoxFormat.XYXY:
elif format is tv_tensors.BoundingBoxFormat.XYXY:
x1, y1 = x, y
x2 = x1 + w
y2 = y1 + h
parts = (x1, y1, x2, y2)
elif format is datapoints.BoundingBoxFormat.CXCYWH:
elif format is tv_tensors.BoundingBoxFormat.CXCYWH:
cx = x + w / 2
cy = y + h / 2
parts = (cx, cy, w, h)
else:
raise ValueError(f"Format {format} is not supported")
return datapoints.BoundingBoxes(
return tv_tensors.BoundingBoxes(
torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size
)
def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtype=None, device="cpu"):
"""Make a "detection" mask, i.e. (*, N, H, W), where each object is encoded as one of N boolean masks"""
return datapoints.Mask(
return tv_tensors.Mask(
torch.testing.make_tensor(
(*batch_dims, num_objects, *size),
low=0,
......@@ -154,7 +154,7 @@ def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtyp
def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=(), dtype=None, device="cpu"):
"""Make a "segmentation" mask, i.e. (*, H, W), where the category is encoded as pixel value"""
return datapoints.Mask(
return tv_tensors.Mask(
torch.testing.make_tensor(
(*batch_dims, *size),
low=0,
......@@ -166,7 +166,7 @@ def make_segmentation_mask(size=DEFAULT_SIZE, *, num_categories=10, batch_dims=(
def make_video(size=DEFAULT_SIZE, *, num_frames=3, batch_dims=(), **kwargs):
return datapoints.Video(make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs))
return tv_tensors.Video(make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs))
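A minimal sketch (not part of this patch) of the two mask conventions used by the helpers above: a stack of per-object boolean masks versus a single category-id map; the sizes are assumptions for illustration:

import torch
from torchvision import tv_tensors

detection_mask = tv_tensors.Mask(torch.zeros(5, 16, 16, dtype=torch.bool))    # (N objects, H, W)
segmentation_mask = tv_tensors.Mask(torch.zeros(16, 16, dtype=torch.uint8))   # (H, W), pixel value = category id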
def make_video_tensor(*args, **kwargs):
......@@ -335,7 +335,7 @@ def make_image_loader_for_interpolation(
image_tensor = image_tensor.to(device=device)
image_tensor = to_dtype_image(image_tensor, dtype=dtype, scale=True)
return datapoints.Image(image_tensor)
return tv_tensors.Image(image_tensor)
return ImageLoader(fn, shape=(num_channels, *size), dtype=dtype, memory_format=memory_format)
......@@ -352,7 +352,7 @@ def make_image_loaders_for_interpolation(
@dataclasses.dataclass
class BoundingBoxesLoader(TensorLoader):
format: datapoints.BoundingBoxFormat
format: tv_tensors.BoundingBoxFormat
spatial_size: Tuple[int, int]
canvas_size: Tuple[int, int] = dataclasses.field(init=False)
......@@ -362,7 +362,7 @@ class BoundingBoxesLoader(TensorLoader):
def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
if isinstance(format, str):
format = datapoints.BoundingBoxFormat[format]
format = tv_tensors.BoundingBoxFormat[format]
spatial_size = _parse_size(spatial_size, name="spatial_size")
......@@ -381,7 +381,7 @@ def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORT
def make_bounding_box_loaders(
*,
extra_dims=tuple(d for d in DEFAULT_EXTRA_DIMS if len(d) < 2),
formats=tuple(datapoints.BoundingBoxFormat),
formats=tuple(tv_tensors.BoundingBoxFormat),
spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
dtypes=(torch.float32, torch.float64, torch.int64),
):
......
......@@ -137,7 +137,7 @@ __all__ = (
# Ref: https://peps.python.org/pep-0562/
def __getattr__(name):
if name in ("wrap_dataset_for_transforms_v2",):
from torchvision.datapoints._dataset_wrapper import wrap_dataset_for_transforms_v2
from torchvision.tv_tensors._dataset_wrapper import wrap_dataset_for_transforms_v2
return wrap_dataset_for_transforms_v2
......
from . import datapoints, models, transforms, utils
from . import models, transforms, tv_tensors, utils
......@@ -6,8 +6,6 @@ import numpy as np
import torch
from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper
from torchvision.datapoints import BoundingBoxes
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
......@@ -16,6 +14,8 @@ from torchvision.prototype.datasets.utils._internal import (
read_categories_file,
read_mat,
)
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes
from .._api import register_dataset, register_info
......
......@@ -4,8 +4,6 @@ from typing import Any, BinaryIO, Dict, Iterator, List, Optional, Sequence, Tupl
import torch
from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper, Zipper
from torchvision.datapoints import BoundingBoxes
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
......@@ -14,6 +12,8 @@ from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
path_accessor,
)
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes
from .._api import register_dataset, register_info
......
......@@ -6,8 +6,6 @@ from typing import Any, BinaryIO, cast, Dict, Iterator, List, Optional, Tuple, U
import numpy as np
from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper
from torchvision.datapoints import Image
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
......@@ -15,6 +13,8 @@ from torchvision.prototype.datasets.utils._internal import (
path_comparator,
read_categories_file,
)
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import Image
from .._api import register_dataset, register_info
......
......@@ -2,7 +2,6 @@ import pathlib
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import Demultiplexer, Filter, IterDataPipe, IterKeyZipper, JsonParser, Mapper, UnBatcher
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
......@@ -12,6 +11,7 @@ from torchvision.prototype.datasets.utils._internal import (
path_accessor,
path_comparator,
)
from torchvision.prototype.tv_tensors import Label
from .._api import register_dataset, register_info
......
......@@ -14,8 +14,6 @@ from torchdata.datapipes.iter import (
Mapper,
UnBatcher,
)
from torchvision.datapoints import BoundingBoxes, Mask
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
......@@ -26,6 +24,8 @@ from torchvision.prototype.datasets.utils._internal import (
path_accessor,
read_categories_file,
)
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes, Mask
from .._api import register_dataset, register_info
......
......@@ -2,7 +2,6 @@ import pathlib
from typing import Any, Dict, List, Tuple, Union
from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
......@@ -10,6 +9,7 @@ from torchvision.prototype.datasets.utils._internal import (
path_comparator,
read_categories_file,
)
from torchvision.prototype.tv_tensors import Label
from .._api import register_dataset, register_info
......
......@@ -15,8 +15,6 @@ from torchdata.datapipes.iter import (
Mapper,
)
from torchdata.datapipes.map import IterToMapConverter
from torchvision.datapoints import BoundingBoxes
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
......@@ -28,6 +26,8 @@ from torchvision.prototype.datasets.utils._internal import (
read_categories_file,
read_mat,
)
from torchvision.prototype.tv_tensors import Label
from torchvision.tv_tensors import BoundingBoxes
from .._api import register_dataset, register_info
......
......@@ -3,7 +3,6 @@ import pathlib
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
from torchdata.datapipes.iter import CSVParser, Demultiplexer, Filter, IterDataPipe, IterKeyZipper, LineReader, Mapper
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
getitem,
......@@ -13,6 +12,7 @@ from torchvision.prototype.datasets.utils._internal import (
path_comparator,
read_categories_file,
)
from torchvision.prototype.tv_tensors import Label
from .._api import register_dataset, register_info
......