Unverified commit 3991ab99 authored by Philip Meier, committed by GitHub

Promote prototype transforms to beta status (#7261)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: vfdev-5 <vfdev.5@gmail.com>
parent d010e82f
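
This commit moves the transforms v2 API out of the prototype namespace. A minimal sketch of the import migration implied by the diffs below, assuming the promoted modules ship as shown; it is not an exhaustive list of the promoted surface:

# Namespace migration sketch (grounded in the import changes below).
#
# Before (prototype):
#   from torchvision.prototype import datapoints, transforms
#   from torchvision.prototype.transforms import functional as F
#
# After (beta):
from torchvision import datapoints                       # Image, BoundingBox, Mask, Video, ...
import torchvision.transforms.v2 as transforms            # transform classes
from torchvision.transforms.v2 import functional as F     # functional kernels
from torchvision.transforms.v2.utils import is_simple_tensor

# Label/OneHotLabel and detection-only transforms such as SimpleCopyPaste and
# FixedSizeCrop stay under torchvision.prototype for now, as the diffs show.
from torchvision.prototype import datapoints as proto_datapoints
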
......@@ -584,11 +584,8 @@ class DatasetTestCase(unittest.TestCase):
@test_all_configs
def test_transforms_v2_wrapper(self, config):
# Although this is a stable test, we unconditionally import from `torchvision.prototype` here. The wrapper needs
# to be available with the next release when v2 is released. Thus, if this import somehow fails on the release
# branch, we screwed up the roll-out
from torchvision.prototype.datapoints import wrap_dataset_for_transforms_v2
from torchvision.prototype.datapoints._datapoint import Datapoint
from torchvision.datapoints import wrap_dataset_for_transforms_v2
from torchvision.datapoints._datapoint import Datapoint
try:
with self.create_dataset(config) as (dataset, _):
......@@ -596,12 +593,13 @@ class DatasetTestCase(unittest.TestCase):
wrapped_sample = wrapped_dataset[0]
assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
except TypeError as error:
if str(error).startswith(f"No wrapper exists for dataset class {type(dataset).__name__}"):
return
msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
if str(error).startswith(msg):
pytest.skip(msg)
raise error
except RuntimeError as error:
if "currently not supported by this wrapper" in str(error):
return
pytest.skip("Config is currently not supported by this wrapper")
raise error
......
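
As a usage reference for the wrapper exercised in the test above, here is a hedged sketch; the concrete dataset is illustrative only and assumes its files are available locally:

from torchvision import datasets
from torchvision.datapoints import wrap_dataset_for_transforms_v2

dataset = datasets.CocoDetection("path/to/images", "path/to/annotations.json")
wrapped = wrap_dataset_for_transforms_v2(dataset)
image, target = wrapped[0]  # target entries now come back as datapoints (BoundingBox, Mask, ...)
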
......@@ -12,12 +12,13 @@ import PIL.Image
import pytest
import torch
import torch.testing
import torchvision.prototype.datapoints as proto_datapoints
from datasets_utils import combinations_grid
from torch.nn.functional import one_hot
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision.prototype import datapoints
from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
from torchvision import datapoints
from torchvision.transforms.functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import convert_dtype_image_tensor, to_image_tensor
__all__ = [
"assert_close",
......@@ -457,7 +458,7 @@ def make_label_loader(*, extra_dims=(), categories=None, dtype=torch.int64):
# The idiom `make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values,
# regardless of the requested dtype, e.g. 0 or 0.0 rather than 0 or 0.123
data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=torch.int64, device=device).to(dtype)
return datapoints.Label(data, categories=categories)
return proto_datapoints.Label(data, categories=categories)
return LabelLoader(fn, shape=extra_dims, dtype=dtype, categories=categories)
......@@ -481,7 +482,7 @@ def make_one_hot_label_loader(*, categories=None, extra_dims=(), dtype=torch.int
# since `one_hot` only supports int64
label = make_label_loader(extra_dims=extra_dims, categories=num_categories, dtype=torch.int64).load(device)
data = one_hot(label, num_classes=num_categories).to(dtype)
return datapoints.OneHotLabel(data, categories=categories)
return proto_datapoints.OneHotLabel(data, categories=categories)
return OneHotLabelLoader(fn, shape=(*extra_dims, num_categories), dtype=dtype, categories=categories)
......
import collections.abc
import pytest
import torchvision.prototype.transforms.functional as F
import torchvision.transforms.v2.functional as F
from prototype_common_utils import InfoBase, TestMark
from prototype_transforms_kernel_infos import KERNEL_INFOS, pad_xfail_jit_fill_condition
from torchvision.prototype import datapoints
from torchvision import datapoints
__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
......
......@@ -8,7 +8,7 @@ import PIL.Image
import pytest
import torch.testing
import torchvision.ops
import torchvision.prototype.transforms.functional as F
import torchvision.transforms.v2.functional as F
from datasets_utils import combinations_grid
from prototype_common_utils import (
ArgsKwargs,
......@@ -28,7 +28,7 @@ from prototype_common_utils import (
TestMark,
)
from torch.utils._pytree import tree_map
from torchvision.prototype import datapoints
from torchvision import datapoints
from torchvision.transforms.functional_tensor import _max_value as get_max_value, _parse_pad_padding
__all__ = ["KernelInfo", "KERNEL_INFOS"]
......@@ -2383,19 +2383,18 @@ KERNEL_INFOS.extend(
def sample_inputs_uniform_temporal_subsample_video():
for video_loader in make_video_loaders(sizes=["random"], num_frames=[4]):
for temporal_dim in [-4, len(video_loader.shape) - 4]:
yield ArgsKwargs(video_loader, num_samples=2, temporal_dim=temporal_dim)
yield ArgsKwargs(video_loader, num_samples=2)
def reference_uniform_temporal_subsample_video(x, num_samples, temporal_dim=-4):
def reference_uniform_temporal_subsample_video(x, num_samples):
# Copy-pasted from
# https://github.com/facebookresearch/pytorchvideo/blob/c8d23d8b7e597586a9e2d18f6ed31ad8aa379a7a/pytorchvideo/transforms/functional.py#L19
t = x.shape[temporal_dim]
t = x.shape[-4]
assert num_samples > 0 and t > 0
# Sample by nearest neighbor interpolation if num_samples > t.
indices = torch.linspace(0, t - 1, num_samples)
indices = torch.clamp(indices, 0, t - 1).long()
return torch.index_select(x, temporal_dim, indices)
return torch.index_select(x, -4, indices)
def reference_inputs_uniform_temporal_subsample_video():
......@@ -2410,12 +2409,5 @@ KERNEL_INFOS.append(
sample_inputs_fn=sample_inputs_uniform_temporal_subsample_video,
reference_fn=reference_uniform_temporal_subsample_video,
reference_inputs_fn=reference_inputs_uniform_temporal_subsample_video,
test_marks=[
TestMark(
("TestKernels", "test_batched_vs_single"),
pytest.mark.skip("Positive `temporal_dim` arguments are not equivalent for batched and single inputs"),
condition=lambda args_kwargs: args_kwargs.kwargs.get("temporal_dim") >= 0,
),
],
)
)
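
For reference, a tiny numeric illustration of what the reference kernel above computes now that the temporal dimension is fixed at -4, matching torchvision's (..., T, C, H, W) video layout; the numbers are only an example:

import torch

video = torch.arange(8).reshape(8, 1, 1, 1)                   # T=8 frames, C=H=W=1
indices = torch.linspace(0, 7, steps=4).clamp(0, 7).long()    # tensor([0, 2, 4, 7])
subsampled = torch.index_select(video, -4, indices)           # keeps frames 0, 2, 4, 7
assert subsampled.shape == (4, 1, 1, 1)
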
......@@ -5,8 +5,8 @@ import torch
from PIL import Image
from torchvision import datasets
from torchvision.prototype import datapoints
from torchvision import datapoints, datasets
from torchvision.prototype import datapoints as proto_datapoints
@pytest.mark.parametrize(
......@@ -24,38 +24,38 @@ from torchvision.prototype import datapoints
],
)
def test_new_requires_grad(data, input_requires_grad, expected_requires_grad):
datapoint = datapoints.Label(data, requires_grad=input_requires_grad)
datapoint = proto_datapoints.Label(data, requires_grad=input_requires_grad)
assert datapoint.requires_grad is expected_requires_grad
def test_isinstance():
assert isinstance(
datapoints.Label([0, 1, 0], categories=["foo", "bar"]),
proto_datapoints.Label([0, 1, 0], categories=["foo", "bar"]),
torch.Tensor,
)
def test_wrapping_no_copy():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = datapoints.Label(tensor, categories=["foo", "bar"])
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
assert label.data_ptr() == tensor.data_ptr()
def test_to_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = datapoints.Label(tensor, categories=["foo", "bar"])
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
label_to = label.to(torch.int32)
assert type(label_to) is datapoints.Label
assert type(label_to) is proto_datapoints.Label
assert label_to.dtype is torch.int32
assert label_to.categories is label.categories
def test_to_datapoint_reference():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32)
label = proto_datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32)
tensor_to = tensor.to(label)
......@@ -65,31 +65,31 @@ def test_to_datapoint_reference():
def test_clone_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = datapoints.Label(tensor, categories=["foo", "bar"])
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
label_clone = label.clone()
assert type(label_clone) is datapoints.Label
assert type(label_clone) is proto_datapoints.Label
assert label_clone.data_ptr() != label.data_ptr()
assert label_clone.categories is label.categories
def test_requires_grad__wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.float32)
label = datapoints.Label(tensor, categories=["foo", "bar"])
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
assert not label.requires_grad
label_requires_grad = label.requires_grad_(True)
assert type(label_requires_grad) is datapoints.Label
assert type(label_requires_grad) is proto_datapoints.Label
assert label.requires_grad
assert label_requires_grad.requires_grad
def test_other_op_no_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = datapoints.Label(tensor, categories=["foo", "bar"])
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
# any operation besides .to() and .clone() will do here
output = label * 2
......@@ -107,33 +107,33 @@ def test_other_op_no_wrapping():
)
def test_no_tensor_output_op_no_wrapping(op):
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = datapoints.Label(tensor, categories=["foo", "bar"])
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
output = op(label)
assert type(output) is not datapoints.Label
assert type(output) is not proto_datapoints.Label
def test_inplace_op_no_wrapping():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = datapoints.Label(tensor, categories=["foo", "bar"])
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
output = label.add_(0)
assert type(output) is torch.Tensor
assert type(label) is datapoints.Label
assert type(label) is proto_datapoints.Label
def test_wrap_like():
tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
label = datapoints.Label(tensor, categories=["foo", "bar"])
label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
# any operation besides .to() and .clone() will do here
output = label * 2
label_new = datapoints.Label.wrap_like(label, output)
label_new = proto_datapoints.Label.wrap_like(label, output)
assert type(label_new) is datapoints.Label
assert type(label_new) is proto_datapoints.Label
assert label_new.data_ptr() == output.data_ptr()
assert label_new.categories is label.categories
......
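
A condensed illustration of the wrapping semantics the tests above assert for the (still prototype) Label datapoint; this mirrors the tests rather than adding new behavior:

import torch
from torchvision.prototype import datapoints as proto_datapoints

label = proto_datapoints.Label(torch.tensor([0, 1, 0]), categories=["foo", "bar"])
assert type(label.to(torch.int32)) is proto_datapoints.Label   # .to() and .clone() keep the subclass
assert type(label * 2) is torch.Tensor                         # other ops unwrap to a plain Tensor
rewrapped = proto_datapoints.Label.wrap_like(label, label * 2)  # re-wrap and carry the categories over
assert rewrapped.categories is label.categories
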
......@@ -5,8 +5,8 @@ from pathlib import Path
import pytest
import torch
import torchvision.transforms.v2 as transforms
import torchvision.prototype.transforms.utils
from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks
from torch.testing._comparison import not_close_error_metas, ObjectPair, TensorLikePair
......@@ -19,10 +19,13 @@ from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.dataloader2.graph.utils import traverse_dps
from torchdata.datapipes.iter import ShardingFilter, Shuffler
from torchdata.datapipes.utils import StreamWrapper
from torchvision import datapoints
from torchvision._utils import sequence_to_str
from torchvision.prototype import datapoints, datasets, transforms
from torchvision.prototype import datasets
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import EncodedImage
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
from torchvision.transforms.v2.utils import is_simple_tensor
def assert_samples_equal(*args, msg=None, **kwargs):
......@@ -141,9 +144,7 @@ class TestCommon:
dataset, _ = dataset_mock.load(config)
sample = next_consume(iter(dataset))
simple_tensors = {
key for key, value in sample.items() if torchvision.prototype.transforms.utils.is_simple_tensor(value)
}
simple_tensors = {key for key, value in sample.items() if is_simple_tensor(value)}
if simple_tensors and not any(
isinstance(item, (datapoints.Image, datapoints.Video, EncodedImage)) for item in sample.values()
......@@ -276,6 +277,6 @@ class TestUSPS:
assert "label" in sample
assert isinstance(sample["image"], datapoints.Image)
assert isinstance(sample["label"], datapoints.Label)
assert isinstance(sample["label"], Label)
assert sample["image"].shape == (1, 16, 16)
......@@ -10,8 +10,11 @@ import numpy as np
import PIL.Image
import pytest
import torch
import torchvision.prototype.datapoints as proto_datapoints
import torchvision.prototype.transforms as proto_transforms
import torchvision.transforms.v2 as transforms
import torchvision.prototype.transforms.utils
import torchvision.transforms.v2.utils
from common_utils import cpu_and_gpu
from prototype_common_utils import (
assert_equal,
......@@ -28,11 +31,12 @@ from prototype_common_utils import (
make_videos,
)
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints
from torchvision.ops.boxes import box_iou
from torchvision.prototype import datapoints, transforms
from torchvision.prototype.transforms import functional as F
from torchvision.prototype.transforms.utils import check_type, is_simple_tensor, query_chw
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2._utils import _convert_fill_arg
from torchvision.transforms.v2.utils import check_type, is_simple_tensor, query_chw
BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
......@@ -281,8 +285,8 @@ class TestSmoke:
],
)
for transform in [
transforms.RandomMixup(alpha=1.0),
transforms.RandomCutmix(alpha=1.0),
proto_transforms.RandomMixup(alpha=1.0),
proto_transforms.RandomCutmix(alpha=1.0),
]
]
)
......@@ -563,7 +567,7 @@ class TestPad:
def test__transform(self, padding, fill, padding_mode, mocker):
transform = transforms.Pad(padding, fill=fill, padding_mode=padding_mode)
fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
fn = mocker.patch("torchvision.transforms.v2.functional.pad")
inpt = mocker.MagicMock(spec=datapoints.Image)
_ = transform(inpt)
......@@ -576,7 +580,7 @@ class TestPad:
def test__transform_image_mask(self, fill, mocker):
transform = transforms.Pad(1, fill=fill, padding_mode="constant")
fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
fn = mocker.patch("torchvision.transforms.v2.functional.pad")
image = datapoints.Image(torch.rand(3, 32, 32))
mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32)))
inpt = [image, mask]
......@@ -634,7 +638,7 @@ class TestRandomZoomOut:
transform = transforms.RandomZoomOut(fill=fill, side_range=side_range, p=1)
fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
fn = mocker.patch("torchvision.transforms.v2.functional.pad")
# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
# Otherwise, we can mock transform._get_params
......@@ -651,7 +655,7 @@ class TestRandomZoomOut:
def test__transform_image_mask(self, fill, mocker):
transform = transforms.RandomZoomOut(fill=fill, p=1.0)
fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
fn = mocker.patch("torchvision.transforms.v2.functional.pad")
image = datapoints.Image(torch.rand(3, 32, 32))
mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32)))
inpt = [image, mask]
......@@ -724,7 +728,7 @@ class TestRandomRotation:
else:
assert transform.degrees == [float(-degrees), float(degrees)]
fn = mocker.patch("torchvision.prototype.transforms.functional.rotate")
fn = mocker.patch("torchvision.transforms.v2.functional.rotate")
inpt = mocker.MagicMock(spec=datapoints.Image)
# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
......@@ -859,7 +863,7 @@ class TestRandomAffine:
else:
assert transform.degrees == [float(-degrees), float(degrees)]
fn = mocker.patch("torchvision.prototype.transforms.functional.affine")
fn = mocker.patch("torchvision.transforms.v2.functional.affine")
inpt = mocker.MagicMock(spec=datapoints.Image)
inpt.num_channels = 3
inpt.spatial_size = (24, 32)
......@@ -964,8 +968,8 @@ class TestRandomCrop:
)
else:
expected.spatial_size = inpt.spatial_size
_ = mocker.patch("torchvision.prototype.transforms.functional.pad", return_value=expected)
fn_crop = mocker.patch("torchvision.prototype.transforms.functional.crop")
_ = mocker.patch("torchvision.transforms.v2.functional.pad", return_value=expected)
fn_crop = mocker.patch("torchvision.transforms.v2.functional.crop")
# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
......@@ -1036,7 +1040,7 @@ class TestGaussianBlur:
else:
assert transform.sigma == [sigma, sigma]
fn = mocker.patch("torchvision.prototype.transforms.functional.gaussian_blur")
fn = mocker.patch("torchvision.transforms.v2.functional.gaussian_blur")
inpt = mocker.MagicMock(spec=datapoints.Image)
inpt.num_channels = 3
inpt.spatial_size = (24, 32)
......@@ -1068,7 +1072,7 @@ class TestRandomColorOp:
def test__transform(self, p, transform_cls, func_op_name, kwargs, mocker):
transform = transform_cls(p=p, **kwargs)
fn = mocker.patch(f"torchvision.prototype.transforms.functional.{func_op_name}")
fn = mocker.patch(f"torchvision.transforms.v2.functional.{func_op_name}")
inpt = mocker.MagicMock(spec=datapoints.Image)
_ = transform(inpt)
if p > 0.0:
......@@ -1104,7 +1108,7 @@ class TestRandomPerspective:
fill = 12
transform = transforms.RandomPerspective(distortion_scale, fill=fill, interpolation=interpolation)
fn = mocker.patch("torchvision.prototype.transforms.functional.perspective")
fn = mocker.patch("torchvision.transforms.v2.functional.perspective")
inpt = mocker.MagicMock(spec=datapoints.Image)
inpt.num_channels = 3
inpt.spatial_size = (24, 32)
......@@ -1178,7 +1182,7 @@ class TestElasticTransform:
else:
assert transform.sigma == sigma
fn = mocker.patch("torchvision.prototype.transforms.functional.elastic")
fn = mocker.patch("torchvision.transforms.v2.functional.elastic")
inpt = mocker.MagicMock(spec=datapoints.Image)
inpt.num_channels = 3
inpt.spatial_size = (24, 32)
......@@ -1251,13 +1255,13 @@ class TestRandomErasing:
w_sentinel = mocker.MagicMock()
v_sentinel = mocker.MagicMock()
mocker.patch(
"torchvision.prototype.transforms._augment.RandomErasing._get_params",
"torchvision.transforms.v2._augment.RandomErasing._get_params",
return_value=dict(i=i_sentinel, j=j_sentinel, h=h_sentinel, w=w_sentinel, v=v_sentinel),
)
inpt_sentinel = mocker.MagicMock()
mock = mocker.patch("torchvision.prototype.transforms._augment.F.erase")
mock = mocker.patch("torchvision.transforms.v2._augment.F.erase")
output = transform(inpt_sentinel)
if p:
......@@ -1300,7 +1304,7 @@ class TestToImageTensor:
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch(
"torchvision.prototype.transforms.functional.to_image_tensor",
"torchvision.transforms.v2.functional.to_image_tensor",
return_value=torch.rand(1, 3, 8, 8),
)
......@@ -1319,7 +1323,7 @@ class TestToImagePIL:
[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch("torchvision.prototype.transforms.functional.to_image_pil")
fn = mocker.patch("torchvision.transforms.v2.functional.to_image_pil")
inpt = mocker.MagicMock(spec=inpt_type)
transform = transforms.ToImagePIL()
......@@ -1336,7 +1340,7 @@ class TestToPILImage:
[torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int],
)
def test__transform(self, inpt_type, mocker):
fn = mocker.patch("torchvision.prototype.transforms.functional.to_image_pil")
fn = mocker.patch("torchvision.transforms.v2.functional.to_image_pil")
inpt = mocker.MagicMock(spec=inpt_type)
transform = transforms.ToPILImage()
......@@ -1443,7 +1447,7 @@ class TestRandomIoUCrop:
transform = transforms.RandomIoUCrop(sampler_options=[2.0])
image = datapoints.Image(torch.rand(1, 3, 4, 4))
bboxes = datapoints.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", spatial_size=(4, 4))
label = datapoints.Label(torch.tensor([1]))
label = proto_datapoints.Label(torch.tensor([1]))
sample = [image, bboxes, label]
# Let's mock transform._get_params to control the output:
transform._get_params = mocker.MagicMock(return_value={})
......@@ -1454,7 +1458,7 @@ class TestRandomIoUCrop:
transform = transforms.RandomIoUCrop()
with pytest.raises(
TypeError,
match="requires input sample to contain Images or PIL Images, BoundingBoxes and Labels or OneHotLabels",
match="requires input sample to contain tensor or PIL images and bounding boxes",
):
transform(torch.tensor(0))
......@@ -1463,13 +1467,11 @@ class TestRandomIoUCrop:
image = datapoints.Image(torch.rand(3, 32, 24))
bboxes = make_bounding_box(format="XYXY", spatial_size=(32, 24), extra_dims=(6,))
label = datapoints.Label(torch.randint(0, 10, size=(6,)))
ohe_label = datapoints.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1))
masks = make_detection_mask((32, 24), num_objects=6)
sample = [image, bboxes, label, ohe_label, masks]
sample = [image, bboxes, masks]
fn = mocker.patch("torchvision.prototype.transforms.functional.crop", side_effect=lambda x, **params: x)
fn = mocker.patch("torchvision.transforms.v2.functional.crop", side_effect=lambda x, **params: x)
is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool)
params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area)
......@@ -1493,17 +1495,7 @@ class TestRandomIoUCrop:
assert isinstance(output_bboxes, datapoints.BoundingBox)
assert len(output_bboxes) == expected_within_targets
# check labels
output_label = output[2]
assert isinstance(output_label, datapoints.Label)
assert len(output_label) == expected_within_targets
torch.testing.assert_close(output_label, label[is_within_crop_area])
output_ohe_label = output[3]
assert isinstance(output_ohe_label, datapoints.OneHotLabel)
torch.testing.assert_close(output_ohe_label, ohe_label[is_within_crop_area])
output_masks = output[4]
output_masks = output[2]
assert isinstance(output_masks, datapoints.Mask)
assert len(output_masks) == expected_within_targets
......@@ -1545,12 +1537,12 @@ class TestScaleJitter:
size_sentinel = mocker.MagicMock()
mocker.patch(
"torchvision.prototype.transforms._geometry.ScaleJitter._get_params", return_value=dict(size=size_sentinel)
"torchvision.transforms.v2._geometry.ScaleJitter._get_params", return_value=dict(size=size_sentinel)
)
inpt_sentinel = mocker.MagicMock()
mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
transform(inpt_sentinel)
mock.assert_called_once_with(
......@@ -1592,13 +1584,13 @@ class TestRandomShortestSize:
size_sentinel = mocker.MagicMock()
mocker.patch(
"torchvision.prototype.transforms._geometry.RandomShortestSize._get_params",
"torchvision.transforms.v2._geometry.RandomShortestSize._get_params",
return_value=dict(size=size_sentinel),
)
inpt_sentinel = mocker.MagicMock()
mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
transform(inpt_sentinel)
mock.assert_called_once_with(
......@@ -1613,13 +1605,13 @@ class TestSimpleCopyPaste:
return mocker.MagicMock(spec=image_type)
def test__extract_image_targets_assertion(self, mocker):
transform = transforms.SimpleCopyPaste()
transform = proto_transforms.SimpleCopyPaste()
flat_sample = [
# images, batch size = 2
self.create_fake_image(mocker, datapoints.Image),
# labels, bboxes, masks
mocker.MagicMock(spec=datapoints.Label),
mocker.MagicMock(spec=proto_datapoints.Label),
mocker.MagicMock(spec=datapoints.BoundingBox),
mocker.MagicMock(spec=datapoints.Mask),
# labels, bboxes, masks
......@@ -1631,9 +1623,9 @@ class TestSimpleCopyPaste:
transform._extract_image_targets(flat_sample)
@pytest.mark.parametrize("image_type", [datapoints.Image, PIL.Image.Image, torch.Tensor])
@pytest.mark.parametrize("label_type", [datapoints.Label, datapoints.OneHotLabel])
@pytest.mark.parametrize("label_type", [proto_datapoints.Label, proto_datapoints.OneHotLabel])
def test__extract_image_targets(self, image_type, label_type, mocker):
transform = transforms.SimpleCopyPaste()
transform = proto_transforms.SimpleCopyPaste()
flat_sample = [
# images, batch size = 2
......@@ -1669,7 +1661,7 @@ class TestSimpleCopyPaste:
assert isinstance(target[key], type_)
assert target[key] in flat_sample
@pytest.mark.parametrize("label_type", [datapoints.Label, datapoints.OneHotLabel])
@pytest.mark.parametrize("label_type", [proto_datapoints.Label, proto_datapoints.OneHotLabel])
def test__copy_paste(self, label_type):
image = 2 * torch.ones(3, 32, 32)
masks = torch.zeros(2, 32, 32)
......@@ -1679,7 +1671,7 @@ class TestSimpleCopyPaste:
blending = True
resize_interpolation = InterpolationMode.BILINEAR
antialias = None
if label_type == datapoints.OneHotLabel:
if label_type == proto_datapoints.OneHotLabel:
labels = torch.nn.functional.one_hot(labels, num_classes=5)
target = {
"boxes": datapoints.BoundingBox(
......@@ -1694,7 +1686,7 @@ class TestSimpleCopyPaste:
paste_masks[0, 13:19, 12:18] = 1
paste_masks[1, 15:19, 1:8] = 1
paste_labels = torch.tensor([3, 4])
if label_type == datapoints.OneHotLabel:
if label_type == proto_datapoints.OneHotLabel:
paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5)
paste_target = {
"boxes": datapoints.BoundingBox(
......@@ -1704,7 +1696,7 @@ class TestSimpleCopyPaste:
"labels": label_type(paste_labels),
}
transform = transforms.SimpleCopyPaste()
transform = proto_transforms.SimpleCopyPaste()
random_selection = torch.tensor([0, 1])
output_image, output_target = transform._copy_paste(
image, target, paste_image, paste_target, random_selection, blending, resize_interpolation, antialias
......@@ -1716,7 +1708,7 @@ class TestSimpleCopyPaste:
torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"])
expected_labels = torch.tensor([1, 2, 3, 4])
if label_type == datapoints.OneHotLabel:
if label_type == proto_datapoints.OneHotLabel:
expected_labels = torch.nn.functional.one_hot(expected_labels, num_classes=5)
torch.testing.assert_close(output_target["labels"], label_type(expected_labels))
......@@ -1731,7 +1723,7 @@ class TestFixedSizeCrop:
batch_shape = (10,)
spatial_size = (11, 5)
transform = transforms.FixedSizeCrop(size=crop_size)
transform = proto_transforms.FixedSizeCrop(size=crop_size)
flat_inputs = [
make_image(size=spatial_size, color_space="RGB"),
......@@ -1759,9 +1751,8 @@ class TestFixedSizeCrop:
fill_sentinel = 12
padding_mode_sentinel = mocker.MagicMock()
transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel)
transform = proto_transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel)
transform._transformed_types = (mocker.MagicMock,)
mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
needs_crop, needs_pad = needs
......@@ -1810,7 +1801,7 @@ class TestFixedSizeCrop:
if not needs_crop:
assert args[0] is inpt_sentinel
assert args[1] is padding_sentinel
fill_sentinel = transforms._utils._convert_fill_arg(fill_sentinel)
fill_sentinel = _convert_fill_arg(fill_sentinel)
assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel)
else:
mock_pad.assert_not_called()
......@@ -1839,8 +1830,7 @@ class TestFixedSizeCrop:
masks = make_detection_mask(size=spatial_size, extra_dims=(batch_size,))
labels = make_label(extra_dims=(batch_size,))
transform = transforms.FixedSizeCrop((-1, -1))
mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
transform = proto_transforms.FixedSizeCrop((-1, -1))
mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
output = transform(
......@@ -1877,8 +1867,7 @@ class TestFixedSizeCrop:
)
mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box")
transform = transforms.FixedSizeCrop((-1, -1))
mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
transform = proto_transforms.FixedSizeCrop((-1, -1))
mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
transform(bounding_box)
......@@ -1922,10 +1911,10 @@ class TestLinearTransformation:
class TestLabelToOneHot:
def test__transform(self):
categories = ["apple", "pear", "pineapple"]
labels = datapoints.Label(torch.tensor([0, 1, 2, 1]), categories=categories)
transform = transforms.LabelToOneHot()
labels = proto_datapoints.Label(torch.tensor([0, 1, 2, 1]), categories=categories)
transform = proto_transforms.LabelToOneHot()
ohe_labels = transform(labels)
assert isinstance(ohe_labels, datapoints.OneHotLabel)
assert isinstance(ohe_labels, proto_datapoints.OneHotLabel)
assert ohe_labels.shape == (4, 3)
assert ohe_labels.categories == labels.categories == categories
......@@ -1956,13 +1945,13 @@ class TestRandomResize:
size_sentinel = mocker.MagicMock()
mocker.patch(
"torchvision.prototype.transforms._geometry.RandomResize._get_params",
"torchvision.transforms.v2._geometry.RandomResize._get_params",
return_value=dict(size=size_sentinel),
)
inpt_sentinel = mocker.MagicMock()
mock_resize = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
mock_resize = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
transform(inpt_sentinel)
mock_resize.assert_called_with(
......@@ -2048,7 +2037,7 @@ class TestPermuteDimensions:
int=0,
)
transform = transforms.PermuteDimensions(dims)
transform = proto_transforms.PermuteDimensions(dims)
transformed_sample = transform(sample)
for key, value in sample.items():
......@@ -2056,7 +2045,7 @@ class TestPermuteDimensions:
transformed_value = transformed_sample[key]
if check_type(
value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video)
value, (datapoints.Image, torchvision.transforms.v2.utils.is_simple_tensor, datapoints.Video)
):
if transform.dims.get(value_type) is not None:
assert transformed_value.permute(inverse_dims[value_type]).equal(value)
......@@ -2067,14 +2056,14 @@ class TestPermuteDimensions:
@pytest.mark.filterwarnings("error")
def test_plain_tensor_call(self):
tensor = torch.empty((2, 3, 4))
transform = transforms.PermuteDimensions(dims=(1, 2, 0))
transform = proto_transforms.PermuteDimensions(dims=(1, 2, 0))
assert transform(tensor).shape == (3, 4, 2)
@pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
def test_plain_tensor_warning(self, other_type):
with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
transforms.PermuteDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})
proto_transforms.PermuteDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})
class TestTransposeDimensions:
......@@ -2094,7 +2083,7 @@ class TestTransposeDimensions:
int=0,
)
transform = transforms.TransposeDimensions(dims)
transform = proto_transforms.TransposeDimensions(dims)
transformed_sample = transform(sample)
for key, value in sample.items():
......@@ -2103,7 +2092,7 @@ class TestTransposeDimensions:
transposed_dims = transform.dims.get(value_type)
if check_type(
value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video)
value, (datapoints.Image, torchvision.transforms.v2.utils.is_simple_tensor, datapoints.Video)
):
if transposed_dims is not None:
assert transformed_value.transpose(*transposed_dims).equal(value)
......@@ -2114,14 +2103,14 @@ class TestTransposeDimensions:
@pytest.mark.filterwarnings("error")
def test_plain_tensor_call(self):
tensor = torch.empty((2, 3, 4))
transform = transforms.TransposeDimensions(dims=(0, 2))
transform = proto_transforms.TransposeDimensions(dims=(0, 2))
assert transform(tensor).shape == (4, 3, 2)
@pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
def test_plain_tensor_warning(self, other_type):
with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
transforms.TransposeDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})
proto_transforms.TransposeDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})
class TestUniformTemporalSubsample:
......
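
To round off these transform tests, a hedged end-to-end sketch of the promoted API: v2 transforms dispatch on datapoint types, so images and bounding boxes move through the same pipeline consistently. The sizes and values here are only illustrative:

import torch
import torchvision.transforms.v2 as transforms
from torchvision import datapoints

image = datapoints.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
boxes = datapoints.BoundingBox(torch.tensor([[2, 2, 10, 10]]), format="XYXY", spatial_size=(32, 32))

pipeline = transforms.Compose([transforms.RandomHorizontalFlip(p=1.0), transforms.Pad(4)])
out_image, out_boxes = pipeline(image, boxes)  # both are flipped and padded together
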
......@@ -12,6 +12,8 @@ import PIL.Image
import pytest
import torch
import torchvision.prototype.transforms as prototype_transforms
import torchvision.transforms.v2 as v2_transforms
from prototype_common_utils import (
ArgsKwargs,
assert_close,
......@@ -24,13 +26,13 @@ from prototype_common_utils import (
make_segmentation_mask,
)
from torch import nn
from torchvision import transforms as legacy_transforms
from torchvision import datapoints, transforms as legacy_transforms
from torchvision._utils import sequence_to_str
from torchvision.prototype import datapoints, transforms as prototype_transforms
from torchvision.prototype.transforms import functional as prototype_F
from torchvision.prototype.transforms.functional import to_image_pil
from torchvision.prototype.transforms.utils import query_spatial_size
from torchvision.transforms import functional as legacy_F
from torchvision.transforms.v2 import functional as prototype_F
from torchvision.transforms.v2.functional import to_image_pil
from torchvision.transforms.v2.utils import query_spatial_size
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])
......@@ -71,7 +73,7 @@ LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] *
CONSISTENCY_CONFIGS = [
ConsistencyConfig(
prototype_transforms.Normalize,
v2_transforms.Normalize,
legacy_transforms.Normalize,
[
ArgsKwargs(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
......@@ -80,14 +82,14 @@ CONSISTENCY_CONFIGS = [
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]),
),
ConsistencyConfig(
prototype_transforms.Resize,
v2_transforms.Resize,
legacy_transforms.Resize,
[
NotScriptableArgsKwargs(32),
ArgsKwargs([32]),
ArgsKwargs((32, 29)),
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC),
ArgsKwargs((30, 27), interpolation=PIL.Image.NEAREST),
ArgsKwargs((35, 29), interpolation=PIL.Image.BILINEAR),
ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC),
......@@ -100,7 +102,7 @@ CONSISTENCY_CONFIGS = [
],
),
ConsistencyConfig(
prototype_transforms.CenterCrop,
v2_transforms.CenterCrop,
legacy_transforms.CenterCrop,
[
ArgsKwargs(18),
......@@ -108,7 +110,7 @@ CONSISTENCY_CONFIGS = [
],
),
ConsistencyConfig(
prototype_transforms.FiveCrop,
v2_transforms.FiveCrop,
legacy_transforms.FiveCrop,
[
ArgsKwargs(18),
......@@ -117,7 +119,7 @@ CONSISTENCY_CONFIGS = [
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
),
ConsistencyConfig(
prototype_transforms.TenCrop,
v2_transforms.TenCrop,
legacy_transforms.TenCrop,
[
ArgsKwargs(18),
......@@ -127,7 +129,7 @@ CONSISTENCY_CONFIGS = [
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
),
ConsistencyConfig(
prototype_transforms.Pad,
v2_transforms.Pad,
legacy_transforms.Pad,
[
NotScriptableArgsKwargs(3),
......@@ -143,7 +145,7 @@ CONSISTENCY_CONFIGS = [
),
*[
ConsistencyConfig(
prototype_transforms.LinearTransformation,
v2_transforms.LinearTransformation,
legacy_transforms.LinearTransformation,
[
ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX.to(matrix_dtype), LINEAR_TRANSFORMATION_MEAN.to(matrix_dtype)),
......@@ -164,7 +166,7 @@ CONSISTENCY_CONFIGS = [
]
],
ConsistencyConfig(
prototype_transforms.Grayscale,
v2_transforms.Grayscale,
legacy_transforms.Grayscale,
[
ArgsKwargs(num_output_channels=1),
......@@ -175,7 +177,7 @@ CONSISTENCY_CONFIGS = [
closeness_kwargs=dict(rtol=None, atol=None),
),
ConsistencyConfig(
prototype_transforms.ConvertDtype,
v2_transforms.ConvertDtype,
legacy_transforms.ConvertImageDtype,
[
ArgsKwargs(torch.float16),
......@@ -189,7 +191,7 @@ CONSISTENCY_CONFIGS = [
closeness_kwargs=dict(rtol=None, atol=None),
),
ConsistencyConfig(
prototype_transforms.ToPILImage,
v2_transforms.ToPILImage,
legacy_transforms.ToPILImage,
[NotScriptableArgsKwargs()],
make_images_kwargs=dict(
......@@ -204,7 +206,7 @@ CONSISTENCY_CONFIGS = [
supports_pil=False,
),
ConsistencyConfig(
prototype_transforms.Lambda,
v2_transforms.Lambda,
legacy_transforms.Lambda,
[
NotScriptableArgsKwargs(lambda image: image / 2),
......@@ -214,7 +216,7 @@ CONSISTENCY_CONFIGS = [
supports_pil=False,
),
ConsistencyConfig(
prototype_transforms.RandomHorizontalFlip,
v2_transforms.RandomHorizontalFlip,
legacy_transforms.RandomHorizontalFlip,
[
ArgsKwargs(p=0),
......@@ -222,7 +224,7 @@ CONSISTENCY_CONFIGS = [
],
),
ConsistencyConfig(
prototype_transforms.RandomVerticalFlip,
v2_transforms.RandomVerticalFlip,
legacy_transforms.RandomVerticalFlip,
[
ArgsKwargs(p=0),
......@@ -230,7 +232,7 @@ CONSISTENCY_CONFIGS = [
],
),
ConsistencyConfig(
prototype_transforms.RandomEqualize,
v2_transforms.RandomEqualize,
legacy_transforms.RandomEqualize,
[
ArgsKwargs(p=0),
......@@ -239,7 +241,7 @@ CONSISTENCY_CONFIGS = [
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
),
ConsistencyConfig(
prototype_transforms.RandomInvert,
v2_transforms.RandomInvert,
legacy_transforms.RandomInvert,
[
ArgsKwargs(p=0),
......@@ -247,7 +249,7 @@ CONSISTENCY_CONFIGS = [
],
),
ConsistencyConfig(
prototype_transforms.RandomPosterize,
v2_transforms.RandomPosterize,
legacy_transforms.RandomPosterize,
[
ArgsKwargs(p=0, bits=5),
......@@ -257,7 +259,7 @@ CONSISTENCY_CONFIGS = [
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
),
ConsistencyConfig(
prototype_transforms.RandomSolarize,
v2_transforms.RandomSolarize,
legacy_transforms.RandomSolarize,
[
ArgsKwargs(p=0, threshold=0.5),
......@@ -267,7 +269,7 @@ CONSISTENCY_CONFIGS = [
),
*[
ConsistencyConfig(
prototype_transforms.RandomAutocontrast,
v2_transforms.RandomAutocontrast,
legacy_transforms.RandomAutocontrast,
[
ArgsKwargs(p=0),
......@@ -279,7 +281,7 @@ CONSISTENCY_CONFIGS = [
for dt, ckw in [(torch.uint8, dict(atol=1, rtol=0)), (torch.float32, dict(rtol=None, atol=None))]
],
ConsistencyConfig(
prototype_transforms.RandomAdjustSharpness,
v2_transforms.RandomAdjustSharpness,
legacy_transforms.RandomAdjustSharpness,
[
ArgsKwargs(p=0, sharpness_factor=0.5),
......@@ -289,7 +291,7 @@ CONSISTENCY_CONFIGS = [
closeness_kwargs={"atol": 1e-6, "rtol": 1e-6},
),
ConsistencyConfig(
prototype_transforms.RandomGrayscale,
v2_transforms.RandomGrayscale,
legacy_transforms.RandomGrayscale,
[
ArgsKwargs(p=0),
......@@ -300,14 +302,14 @@ CONSISTENCY_CONFIGS = [
closeness_kwargs=dict(rtol=None, atol=None),
),
ConsistencyConfig(
prototype_transforms.RandomResizedCrop,
v2_transforms.RandomResizedCrop,
legacy_transforms.RandomResizedCrop,
[
ArgsKwargs(16),
ArgsKwargs(17, scale=(0.3, 0.7)),
ArgsKwargs(25, ratio=(0.5, 1.5)),
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC),
ArgsKwargs((31, 28), interpolation=PIL.Image.NEAREST),
ArgsKwargs((33, 26), interpolation=PIL.Image.BICUBIC),
ArgsKwargs((29, 32), antialias=False),
......@@ -315,7 +317,7 @@ CONSISTENCY_CONFIGS = [
],
),
ConsistencyConfig(
prototype_transforms.RandomErasing,
v2_transforms.RandomErasing,
legacy_transforms.RandomErasing,
[
ArgsKwargs(p=0),
......@@ -329,7 +331,7 @@ CONSISTENCY_CONFIGS = [
supports_pil=False,
),
ConsistencyConfig(
prototype_transforms.ColorJitter,
v2_transforms.ColorJitter,
legacy_transforms.ColorJitter,
[
ArgsKwargs(),
......@@ -347,7 +349,7 @@ CONSISTENCY_CONFIGS = [
),
*[
ConsistencyConfig(
prototype_transforms.ElasticTransform,
v2_transforms.ElasticTransform,
legacy_transforms.ElasticTransform,
[
ArgsKwargs(),
......@@ -355,8 +357,8 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(alpha=(15.3, 27.2)),
ArgsKwargs(sigma=3.0),
ArgsKwargs(sigma=(2.5, 3.9)),
ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.BICUBIC),
ArgsKwargs(interpolation=v2_transforms.InterpolationMode.NEAREST),
ArgsKwargs(interpolation=v2_transforms.InterpolationMode.BICUBIC),
ArgsKwargs(interpolation=PIL.Image.NEAREST),
ArgsKwargs(interpolation=PIL.Image.BICUBIC),
ArgsKwargs(fill=1),
......@@ -370,7 +372,7 @@ CONSISTENCY_CONFIGS = [
for dt, ckw in [(torch.uint8, {"rtol": 1e-1, "atol": 1}), (torch.float32, {"rtol": 1e-2, "atol": 1e-3})]
],
ConsistencyConfig(
prototype_transforms.GaussianBlur,
v2_transforms.GaussianBlur,
legacy_transforms.GaussianBlur,
[
ArgsKwargs(kernel_size=3),
......@@ -381,7 +383,7 @@ CONSISTENCY_CONFIGS = [
closeness_kwargs={"rtol": 1e-5, "atol": 1e-5},
),
ConsistencyConfig(
prototype_transforms.RandomAffine,
v2_transforms.RandomAffine,
legacy_transforms.RandomAffine,
[
ArgsKwargs(degrees=30.0),
......@@ -392,7 +394,7 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(degrees=0.0, shear=(8, 17)),
ArgsKwargs(degrees=0.0, shear=(4, 5, 4, 13)),
ArgsKwargs(degrees=(-20.0, 10.0), translate=(0.4, 0.6), scale=(0.3, 0.8), shear=(4, 5, 4, 13)),
ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs(degrees=30.0, interpolation=v2_transforms.InterpolationMode.NEAREST),
ArgsKwargs(degrees=30.0, interpolation=PIL.Image.NEAREST),
ArgsKwargs(degrees=30.0, fill=1),
ArgsKwargs(degrees=30.0, fill=(2, 3, 4)),
......@@ -401,7 +403,7 @@ CONSISTENCY_CONFIGS = [
removed_params=["fillcolor", "resample"],
),
ConsistencyConfig(
prototype_transforms.RandomCrop,
v2_transforms.RandomCrop,
legacy_transforms.RandomCrop,
[
ArgsKwargs(12),
......@@ -421,13 +423,13 @@ CONSISTENCY_CONFIGS = [
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(26, 26), (18, 33), (29, 22)]),
),
ConsistencyConfig(
prototype_transforms.RandomPerspective,
v2_transforms.RandomPerspective,
legacy_transforms.RandomPerspective,
[
ArgsKwargs(p=0),
ArgsKwargs(p=1),
ArgsKwargs(p=1, distortion_scale=0.3),
ArgsKwargs(p=1, distortion_scale=0.2, interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs(p=1, distortion_scale=0.2, interpolation=v2_transforms.InterpolationMode.NEAREST),
ArgsKwargs(p=1, distortion_scale=0.2, interpolation=PIL.Image.NEAREST),
ArgsKwargs(p=1, distortion_scale=0.1, fill=1),
ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)),
......@@ -435,12 +437,12 @@ CONSISTENCY_CONFIGS = [
closeness_kwargs={"atol": None, "rtol": None},
),
ConsistencyConfig(
prototype_transforms.RandomRotation,
v2_transforms.RandomRotation,
legacy_transforms.RandomRotation,
[
ArgsKwargs(degrees=30.0),
ArgsKwargs(degrees=(-20.0, 10.0)),
ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.BILINEAR),
ArgsKwargs(degrees=30.0, interpolation=v2_transforms.InterpolationMode.BILINEAR),
ArgsKwargs(degrees=30.0, interpolation=PIL.Image.BILINEAR),
ArgsKwargs(degrees=30.0, expand=True),
ArgsKwargs(degrees=30.0, center=(0, 0)),
......@@ -450,43 +452,43 @@ CONSISTENCY_CONFIGS = [
removed_params=["resample"],
),
ConsistencyConfig(
prototype_transforms.PILToTensor,
v2_transforms.PILToTensor,
legacy_transforms.PILToTensor,
),
ConsistencyConfig(
prototype_transforms.ToTensor,
v2_transforms.ToTensor,
legacy_transforms.ToTensor,
),
ConsistencyConfig(
prototype_transforms.Compose,
v2_transforms.Compose,
legacy_transforms.Compose,
),
ConsistencyConfig(
prototype_transforms.RandomApply,
v2_transforms.RandomApply,
legacy_transforms.RandomApply,
),
ConsistencyConfig(
prototype_transforms.RandomChoice,
v2_transforms.RandomChoice,
legacy_transforms.RandomChoice,
),
ConsistencyConfig(
prototype_transforms.RandomOrder,
v2_transforms.RandomOrder,
legacy_transforms.RandomOrder,
),
ConsistencyConfig(
prototype_transforms.AugMix,
v2_transforms.AugMix,
legacy_transforms.AugMix,
),
ConsistencyConfig(
prototype_transforms.AutoAugment,
v2_transforms.AutoAugment,
legacy_transforms.AutoAugment,
),
ConsistencyConfig(
prototype_transforms.RandAugment,
v2_transforms.RandAugment,
legacy_transforms.RandAugment,
),
ConsistencyConfig(
prototype_transforms.TrivialAugmentWide,
v2_transforms.TrivialAugmentWide,
legacy_transforms.TrivialAugmentWide,
),
]
......@@ -680,19 +682,19 @@ get_params_parametrization = pytest.mark.parametrize(
id=transform_cls.__name__,
)
for transform_cls, get_params_args_kwargs in [
(prototype_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])),
(prototype_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))),
(prototype_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
(prototype_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])),
(prototype_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)),
(v2_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])),
(v2_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))),
(v2_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
(v2_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])),
(v2_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)),
(
prototype_transforms.RandomAffine,
v2_transforms.RandomAffine,
ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]),
),
(prototype_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))),
(prototype_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)),
(prototype_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])),
(prototype_transforms.AutoAugment, ArgsKwargs(5)),
(v2_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))),
(v2_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)),
(v2_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])),
(v2_transforms.AutoAugment, ArgsKwargs(5)),
]
],
)
......@@ -767,10 +769,10 @@ class TestContainerTransforms:
"""
def test_compose(self):
prototype_transform = prototype_transforms.Compose(
prototype_transform = v2_transforms.Compose(
[
prototype_transforms.Resize(256),
prototype_transforms.CenterCrop(224),
v2_transforms.Resize(256),
v2_transforms.CenterCrop(224),
]
)
legacy_transform = legacy_transforms.Compose(
......@@ -785,11 +787,11 @@ class TestContainerTransforms:
@pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
@pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
def test_random_apply(self, p, sequence_type):
prototype_transform = prototype_transforms.RandomApply(
prototype_transform = v2_transforms.RandomApply(
sequence_type(
[
prototype_transforms.Resize(256),
prototype_transforms.CenterCrop(224),
v2_transforms.Resize(256),
v2_transforms.CenterCrop(224),
]
),
p=p,
......@@ -814,9 +816,9 @@ class TestContainerTransforms:
# We can't test other values for `p` since the random parameter generation is different
@pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)])
def test_random_choice(self, probabilities):
prototype_transform = prototype_transforms.RandomChoice(
prototype_transform = v2_transforms.RandomChoice(
[
prototype_transforms.Resize(256),
v2_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
],
probabilities=probabilities,
......@@ -834,7 +836,7 @@ class TestContainerTransforms:
class TestToTensorTransforms:
def test_pil_to_tensor(self):
prototype_transform = prototype_transforms.PILToTensor()
prototype_transform = v2_transforms.PILToTensor()
legacy_transform = legacy_transforms.PILToTensor()
for image in make_images(extra_dims=[()]):
......@@ -844,7 +846,7 @@ class TestToTensorTransforms:
def test_to_tensor(self):
with pytest.warns(UserWarning, match=re.escape("The transform `ToTensor()` is deprecated")):
prototype_transform = prototype_transforms.ToTensor()
prototype_transform = v2_transforms.ToTensor()
legacy_transform = legacy_transforms.ToTensor()
for image in make_images(extra_dims=[()]):
......@@ -867,14 +869,14 @@ class TestAATransforms:
@pytest.mark.parametrize(
"interpolation",
[
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_randaug(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
t = prototype_transforms.RandAugment(interpolation=interpolation, num_ops=1)
t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1)
le = len(t._AUGMENTATION_SPACE)
keys = list(t._AUGMENTATION_SPACE.keys())
......@@ -909,14 +911,14 @@ class TestAATransforms:
@pytest.mark.parametrize(
"interpolation",
[
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_trivial_aug(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
t = prototype_transforms.TrivialAugmentWide(interpolation=interpolation)
t = v2_transforms.TrivialAugmentWide(interpolation=interpolation)
le = len(t._AUGMENTATION_SPACE)
keys = list(t._AUGMENTATION_SPACE.keys())
......@@ -961,15 +963,15 @@ class TestAATransforms:
@pytest.mark.parametrize(
"interpolation",
[
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_augmix(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
t_ref._sample_dirichlet = lambda t: t.softmax(dim=-1)
t = prototype_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
t._sample_dirichlet = lambda t: t.softmax(dim=-1)
le = len(t._AUGMENTATION_SPACE)
......@@ -1014,15 +1016,15 @@ class TestAATransforms:
@pytest.mark.parametrize(
"interpolation",
[
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
v2_transforms.InterpolationMode.NEAREST,
v2_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_aa(self, inpt, interpolation):
aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation)
t = prototype_transforms.AutoAugment(aa_policy, interpolation=interpolation)
t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation)
torch.manual_seed(12)
expected_output = t_ref(inpt)
......@@ -1087,10 +1089,16 @@ class TestRefDetTransforms:
@pytest.mark.parametrize(
"t_ref, t, data_kwargs",
[
(det_transforms.RandomHorizontalFlip(p=1.0), prototype_transforms.RandomHorizontalFlip(p=1.0), {}),
(det_transforms.RandomIoUCrop(), prototype_transforms.RandomIoUCrop(), {"with_mask": False}),
(det_transforms.RandomZoomOut(), prototype_transforms.RandomZoomOut(), {"with_mask": False}),
(det_transforms.ScaleJitter((1024, 1024)), prototype_transforms.ScaleJitter((1024, 1024)), {}),
(det_transforms.RandomHorizontalFlip(p=1.0), v2_transforms.RandomHorizontalFlip(p=1.0), {}),
# FIXME: make
# v2_transforms.Compose([
# v2_transforms.RandomIoUCrop(),
# v2_transforms.SanitizeBoundingBoxes()
# ])
# work
# (det_transforms.RandomIoUCrop(), v2_transforms.RandomIoUCrop(), {"with_mask": False}),
(det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}),
(det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024)), {}),
(
det_transforms.FixedSizeCrop((1024, 1024), fill=0),
prototype_transforms.FixedSizeCrop((1024, 1024), fill=0),
......@@ -1100,7 +1108,7 @@ class TestRefDetTransforms:
det_transforms.RandomShortestSize(
min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
),
prototype_transforms.RandomShortestSize(
v2_transforms.RandomShortestSize(
min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
),
{},
......@@ -1127,11 +1135,11 @@ seg_transforms = import_transforms_from_references("segmentation")
# 1. transforms.RandomCrop uses a different scheme to pad images and masks of insufficient size than its name
# counterpart in the detection references. Thus, we cannot use it with `pad_if_needed=True`
# 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size.
class PadIfSmaller(prototype_transforms.Transform):
class PadIfSmaller(v2_transforms.Transform):
def __init__(self, size, fill=0):
super().__init__()
self.size = size
self.fill = prototype_transforms._geometry._setup_fill_arg(fill)
self.fill = v2_transforms._geometry._setup_fill_arg(fill)
def _get_params(self, sample):
height, width = query_spatial_size(sample)
......@@ -1193,27 +1201,27 @@ class TestRefSegTransforms:
[
(
seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
prototype_transforms.RandomHorizontalFlip(p=1.0),
v2_transforms.RandomHorizontalFlip(p=1.0),
dict(),
),
(
seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
prototype_transforms.RandomHorizontalFlip(p=0.0),
v2_transforms.RandomHorizontalFlip(p=0.0),
dict(),
),
(
seg_transforms.RandomCrop(size=480),
prototype_transforms.Compose(
v2_transforms.Compose(
[
PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {datapoints.Mask: 255})),
prototype_transforms.RandomCrop(size=480),
v2_transforms.RandomCrop(size=480),
]
),
dict(),
),
(
seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
prototype_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
v2_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
dict(supports_pil=False, image_dtype=torch.float),
),
],
......@@ -1222,7 +1230,7 @@ class TestRefSegTransforms:
self.check(t, t_ref, data_kwargs)
def check_resize(self, mocker, t_ref, t):
mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
mock_ref = mocker.patch("torchvision.transforms.functional.resize")
for dp, dp_ref in self.make_datapoints():
......@@ -1263,9 +1271,9 @@ class TestRefSegTransforms:
# We are patching torch.randint -> random.randint here, because we can't patch the modules that are not imported
# normally
t = prototype_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True)
t = v2_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True)
mocker.patch(
"torchvision.prototype.transforms._geometry.torch.randint",
"torchvision.transforms.v2._geometry.torch.randint",
new=patched_randint,
)
......@@ -1277,7 +1285,7 @@ class TestRefSegTransforms:
torch.manual_seed(0)
base_size = 520
t = prototype_transforms.Resize(size=base_size, antialias=True)
t = v2_transforms.Resize(size=base_size, antialias=True)
t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size)
......
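
The consistency configs above compare each v2 transform against its legacy counterpart. A minimal sketch of that idea, using a deterministic p=1.0 flip so the comparison does not depend on RNG consumption:

import torch
import torchvision.transforms as legacy_transforms
import torchvision.transforms.v2 as v2_transforms

image = torch.rand(3, 32, 32)
legacy_out = legacy_transforms.RandomHorizontalFlip(p=1.0)(image)
v2_out = v2_transforms.RandomHorizontalFlip(p=1.0)(image)
torch.testing.assert_close(v2_out, legacy_out)
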
......@@ -11,7 +11,6 @@ import pytest
import torch
import torchvision.prototype.transforms.utils
from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed
from prototype_common_utils import (
assert_close,
......@@ -22,11 +21,12 @@ from prototype_common_utils import (
from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
from prototype_transforms_kernel_infos import KERNEL_INFOS
from torch.utils._pytree import tree_map
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F
from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding
from torchvision.prototype.transforms.functional._meta import clamp_bounding_box, convert_format_bounding_box
from torchvision import datapoints
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
from torchvision.transforms.v2.functional._meta import clamp_bounding_box, convert_format_bounding_box
from torchvision.transforms.v2.utils import is_simple_tensor
KERNEL_INFOS_MAP = {info.kernel: info for info in KERNEL_INFOS}
......@@ -168,11 +168,7 @@ class TestKernels:
def test_batched_vs_single(self, test_id, info, args_kwargs, device):
(batched_input, *other_args), kwargs = args_kwargs.load(device)
datapoint_type = (
datapoints.Image
if torchvision.prototype.transforms.utils.is_simple_tensor(batched_input)
else type(batched_input)
)
datapoint_type = datapoints.Image if is_simple_tensor(batched_input) else type(batched_input)
# This dictionary contains the number of rightmost dimensions that contain the actual data.
# Everything to the left is considered a batch dimension.
data_dims = {
......
......@@ -3,12 +3,12 @@ import pytest
import torch
import torchvision.prototype.transforms.utils
import torchvision.transforms.v2.utils
from prototype_common_utils import make_bounding_box, make_detection_mask, make_image
from torchvision.prototype import datapoints
from torchvision.prototype.transforms.functional import to_image_pil
from torchvision.prototype.transforms.utils import has_all, has_any
from torchvision import datapoints
from torchvision.transforms.v2.functional import to_image_pil
from torchvision.transforms.v2.utils import has_all, has_any
IMAGE = make_image(color_space="RGB")
......@@ -37,15 +37,15 @@ MASK = make_detection_mask(size=IMAGE.spatial_size)
((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.prototype.transforms.utils.is_simple_tensor), True),
((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor), True),
(
(torch.Tensor(IMAGE),),
(datapoints.Image, PIL.Image.Image, torchvision.prototype.transforms.utils.is_simple_tensor),
(datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor),
True,
),
(
(to_image_pil(IMAGE),),
(datapoints.Image, PIL.Image.Image, torchvision.prototype.transforms.utils.is_simple_tensor),
(datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor),
True,
),
],
......
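
A hedged sketch of the helpers exercised above in their new home under torchvision.transforms.v2.utils; the inputs are illustrative:

import torch
from torchvision import datapoints
from torchvision.transforms.v2.utils import has_any, is_simple_tensor

image = datapoints.Image(torch.rand(3, 8, 8))
plain = torch.rand(3, 8, 8)

assert is_simple_tensor(plain)       # a bare Tensor counts as "simple"
assert not is_simple_tensor(image)   # datapoint subclasses do not
assert has_any([image, plain], datapoints.Image, is_simple_tensor)
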
from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT
from ._image import Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
from ._mask import Mask
from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT
from ._dataset_wrapper import wrap_dataset_for_transforms_v2 # type: ignore[attr-defined] # usort: skip
......@@ -105,7 +105,7 @@ class Datapoint(torch.Tensor):
# the class. This approach avoids the DataLoader issue described at
# https://github.com/pytorch/vision/pull/6476#discussion_r953588621
if Datapoint.__F is None:
from ..transforms import functional
from ..transforms.v2 import functional
Datapoint.__F = functional
return Datapoint.__F
......
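
The hunk above defers the functional import until first use and caches it on the class, to avoid the DataLoader issue linked in the comment. A generic sketch of that pattern, with hypothetical names rather than torchvision's own:

import types
from typing import Optional

class LazyF:
    _functional: Optional[types.ModuleType] = None  # hypothetical cache attribute

    @classmethod
    def _F(cls) -> types.ModuleType:
        # Import lazily on first access and cache on the class, so merely
        # constructing instances never triggers the heavy import.
        if cls._functional is None:
            from torchvision.transforms.v2 import functional
            cls._functional = functional
        return cls._functional
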
......@@ -8,9 +8,8 @@ from collections import defaultdict
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F
from torchvision import datapoints, datasets
from torchvision.transforms.v2 import functional as F
__all__ = ["wrap_dataset_for_transforms_v2"]
......
......@@ -24,7 +24,7 @@ class Image(Datapoint):
requires_grad: Optional[bool] = None,
) -> Image:
if isinstance(data, PIL.Image.Image):
from torchvision.prototype.transforms import functional as F
from torchvision.transforms.v2 import functional as F
data = F.pil_to_tensor(data)
......
......@@ -23,7 +23,7 @@ class Mask(Datapoint):
requires_grad: Optional[bool] = None,
) -> Mask:
if isinstance(data, PIL.Image.Image):
from torchvision.prototype.transforms import functional as F
from torchvision.transforms.v2 import functional as F
data = F.pil_to_tensor(data)
......
from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT
from ._image import Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
from ._label import Label, OneHotLabel
from ._mask import Mask
from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT
from ._dataset_wrapper import wrap_dataset_for_transforms_v2 # type: ignore[attr-defined] # usort: skip
......@@ -5,7 +5,7 @@ from typing import Any, Optional, Sequence, Type, TypeVar, Union
import torch
from torch.utils._pytree import tree_map
from ._datapoint import Datapoint
from torchvision.datapoints._datapoint import Datapoint
L = TypeVar("L", bound="_LabelBase")
......
......@@ -6,7 +6,8 @@ import numpy as np
import torch
from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper
from torchvision.prototype.datapoints import BoundingBox, Label
from torchvision.datapoints import BoundingBox
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
hint_sharding,
......