Unverified commit 3991ab99, authored by Philip Meier and committed by GitHub

Promote prototype transforms to beta status (#7261)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: vfdev-5 <vfdev.5@gmail.com>
parent d010e82f
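
Most of the diff below is a mechanical rename of import paths in the test suite. A minimal sketch of the migration, using only module names that appear in the hunks themselves (everything else about the beta API surface is left untouched):

    # Before this commit, the v2 machinery lived under the prototype namespace:
    #   from torchvision.prototype import datapoints
    #   import torchvision.prototype.transforms as transforms
    #   import torchvision.prototype.transforms.functional as F

    # After this commit, the promoted beta namespaces are used instead:
    from torchvision import datapoints
    import torchvision.transforms.v2 as transforms
    import torchvision.transforms.v2.functional as F

    # Label/OneHotLabel and a few transforms (SimpleCopyPaste, FixedSizeCrop, ...)
    # are not promoted; tests keep importing them from the prototype package:
    import torchvision.prototype.datapoints as proto_datapoints
    import torchvision.prototype.transforms as proto_transforms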
@@ -584,11 +584,8 @@ class DatasetTestCase(unittest.TestCase):
     @test_all_configs
     def test_transforms_v2_wrapper(self, config):
-        # Although this is a stable test, we unconditionally import from `torchvision.prototype` here. The wrapper needs
-        # to be available with the next release when v2 is released. Thus, if this import somehow fails on the release
-        # branch, we screwed up the roll-out
-        from torchvision.prototype.datapoints import wrap_dataset_for_transforms_v2
-        from torchvision.prototype.datapoints._datapoint import Datapoint
+        from torchvision.datapoints import wrap_dataset_for_transforms_v2
+        from torchvision.datapoints._datapoint import Datapoint
         try:
             with self.create_dataset(config) as (dataset, _):
@@ -596,12 +593,13 @@ class DatasetTestCase(unittest.TestCase):
                 wrapped_sample = wrapped_dataset[0]
                 assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
         except TypeError as error:
-            if str(error).startswith(f"No wrapper exists for dataset class {type(dataset).__name__}"):
-                return
+            msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
+            if str(error).startswith(msg):
+                pytest.skip(msg)
             raise error
         except RuntimeError as error:
             if "currently not supported by this wrapper" in str(error):
-                return
+                pytest.skip("Config is currently not supported by this wrapper")
             raise error
...
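
For context, the wrapper exercised in the test above re-packages a stable dataset's samples as v2 datapoints so they can flow through torchvision.transforms.v2 pipelines. A minimal usage sketch, assuming the entry point promoted here (torchvision.datapoints.wrap_dataset_for_transforms_v2) and using VOCDetection purely as an illustrative dataset:

    import torchvision
    from torchvision.datapoints import wrap_dataset_for_transforms_v2

    # A plain detection dataset returns (PIL image, target-dict) samples.
    dataset = torchvision.datasets.VOCDetection(root="data", year="2012", image_set="train")

    # The wrapper rewraps the targets as datapoints (BoundingBox, Mask, ...);
    # unsupported datasets or configs raise the TypeError / RuntimeError that the
    # test above turns into pytest.skip.
    wrapped = wrap_dataset_for_transforms_v2(dataset)
    image, target = wrapped[0]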
@@ -12,12 +12,13 @@ import PIL.Image
import pytest
import torch
import torch.testing
+import torchvision.prototype.datapoints as proto_datapoints
from datasets_utils import combinations_grid
from torch.nn.functional import one_hot
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
-from torchvision.prototype import datapoints
-from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
+from torchvision import datapoints
from torchvision.transforms.functional_tensor import _max_value as get_max_value
+from torchvision.transforms.v2.functional import convert_dtype_image_tensor, to_image_tensor
__all__ = [
    "assert_close",
@@ -457,7 +458,7 @@ def make_label_loader(*, extra_dims=(), categories=None, dtype=torch.int64):
        # The idiom `make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values,
        # regardless of the requested dtype, e.g. 0 or 0.0 rather than 0 or 0.123
        data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=torch.int64, device=device).to(dtype)
-        return datapoints.Label(data, categories=categories)
+        return proto_datapoints.Label(data, categories=categories)
    return LabelLoader(fn, shape=extra_dims, dtype=dtype, categories=categories)
@@ -481,7 +482,7 @@ def make_one_hot_label_loader(*, categories=None, extra_dims=(), dtype=torch.int
        # since `one_hot` only supports int64
        label = make_label_loader(extra_dims=extra_dims, categories=num_categories, dtype=torch.int64).load(device)
        data = one_hot(label, num_classes=num_categories).to(dtype)
-        return datapoints.OneHotLabel(data, categories=categories)
+        return proto_datapoints.OneHotLabel(data, categories=categories)
    return OneHotLabelLoader(fn, shape=(*extra_dims, num_categories), dtype=dtype, categories=categories)
...
import collections.abc
import pytest
-import torchvision.prototype.transforms.functional as F
+import torchvision.transforms.v2.functional as F
from prototype_common_utils import InfoBase, TestMark
from prototype_transforms_kernel_infos import KERNEL_INFOS, pad_xfail_jit_fill_condition
-from torchvision.prototype import datapoints
+from torchvision import datapoints
__all__ = ["DispatcherInfo", "DISPATCHER_INFOS"]
...
@@ -8,7 +8,7 @@ import PIL.Image
import pytest
import torch.testing
import torchvision.ops
-import torchvision.prototype.transforms.functional as F
+import torchvision.transforms.v2.functional as F
from datasets_utils import combinations_grid
from prototype_common_utils import (
    ArgsKwargs,
@@ -28,7 +28,7 @@ from prototype_common_utils import (
    TestMark,
)
from torch.utils._pytree import tree_map
-from torchvision.prototype import datapoints
+from torchvision import datapoints
from torchvision.transforms.functional_tensor import _max_value as get_max_value, _parse_pad_padding
__all__ = ["KernelInfo", "KERNEL_INFOS"]
@@ -2383,19 +2383,18 @@ KERNEL_INFOS.extend(
def sample_inputs_uniform_temporal_subsample_video():
    for video_loader in make_video_loaders(sizes=["random"], num_frames=[4]):
-        for temporal_dim in [-4, len(video_loader.shape) - 4]:
-            yield ArgsKwargs(video_loader, num_samples=2, temporal_dim=temporal_dim)
+        yield ArgsKwargs(video_loader, num_samples=2)
-def reference_uniform_temporal_subsample_video(x, num_samples, temporal_dim=-4):
+def reference_uniform_temporal_subsample_video(x, num_samples):
    # Copy-pasted from
    # https://github.com/facebookresearch/pytorchvideo/blob/c8d23d8b7e597586a9e2d18f6ed31ad8aa379a7a/pytorchvideo/transforms/functional.py#L19
-    t = x.shape[temporal_dim]
+    t = x.shape[-4]
    assert num_samples > 0 and t > 0
    # Sample by nearest neighbor interpolation if num_samples > t.
    indices = torch.linspace(0, t - 1, num_samples)
    indices = torch.clamp(indices, 0, t - 1).long()
-    return torch.index_select(x, temporal_dim, indices)
+    return torch.index_select(x, -4, indices)
def reference_inputs_uniform_temporal_subsample_video():
@@ -2410,12 +2409,5 @@ KERNEL_INFOS.append(
        sample_inputs_fn=sample_inputs_uniform_temporal_subsample_video,
        reference_fn=reference_uniform_temporal_subsample_video,
        reference_inputs_fn=reference_inputs_uniform_temporal_subsample_video,
-        test_marks=[
-            TestMark(
-                ("TestKernels", "test_batched_vs_single"),
-                pytest.mark.skip("Positive `temporal_dim` arguments are not equivalent for batched and single inputs"),
-                condition=lambda args_kwargs: args_kwargs.kwargs.get("temporal_dim") >= 0,
-            ),
-        ],
    )
)
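
Dropping the temporal_dim argument above means the kernel is always exercised along dimension -4. A standalone sketch of the reference behaviour, mirroring the copied pytorchvideo function in the hunk:

    import torch

    def uniform_temporal_subsample_reference(video: torch.Tensor, num_samples: int) -> torch.Tensor:
        # Videos are laid out as (..., T, C, H, W), so the temporal axis is always -4.
        t = video.shape[-4]
        indices = torch.linspace(0, t - 1, num_samples).clamp(0, t - 1).long()
        return torch.index_select(video, -4, indices)

    video = torch.rand(8, 3, 16, 16)  # T=8 frames
    assert uniform_temporal_subsample_reference(video, num_samples=4).shape == (4, 3, 16, 16)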
@@ -5,8 +5,8 @@ import torch
from PIL import Image
-from torchvision import datasets
-from torchvision.prototype import datapoints
+from torchvision import datapoints, datasets
+from torchvision.prototype import datapoints as proto_datapoints
@pytest.mark.parametrize(
@@ -24,38 +24,38 @@ from torchvision.prototype import datapoints
    ],
)
def test_new_requires_grad(data, input_requires_grad, expected_requires_grad):
-    datapoint = datapoints.Label(data, requires_grad=input_requires_grad)
+    datapoint = proto_datapoints.Label(data, requires_grad=input_requires_grad)
    assert datapoint.requires_grad is expected_requires_grad
def test_isinstance():
    assert isinstance(
-        datapoints.Label([0, 1, 0], categories=["foo", "bar"]),
+        proto_datapoints.Label([0, 1, 0], categories=["foo", "bar"]),
        torch.Tensor,
    )
def test_wrapping_no_copy():
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
    assert label.data_ptr() == tensor.data_ptr()
def test_to_wrapping():
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
    label_to = label.to(torch.int32)
-    assert type(label_to) is datapoints.Label
+    assert type(label_to) is proto_datapoints.Label
    assert label_to.dtype is torch.int32
    assert label_to.categories is label.categories
def test_to_datapoint_reference():
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32)
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"]).to(torch.int32)
    tensor_to = tensor.to(label)
@@ -65,31 +65,31 @@ def test_to_datapoint_reference():
def test_clone_wrapping():
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
    label_clone = label.clone()
-    assert type(label_clone) is datapoints.Label
+    assert type(label_clone) is proto_datapoints.Label
    assert label_clone.data_ptr() != label.data_ptr()
    assert label_clone.categories is label.categories
def test_requires_grad__wrapping():
    tensor = torch.tensor([0, 1, 0], dtype=torch.float32)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
    assert not label.requires_grad
    label_requires_grad = label.requires_grad_(True)
-    assert type(label_requires_grad) is datapoints.Label
+    assert type(label_requires_grad) is proto_datapoints.Label
    assert label.requires_grad
    assert label_requires_grad.requires_grad
def test_other_op_no_wrapping():
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
    # any operation besides .to() and .clone() will do here
    output = label * 2
@@ -107,33 +107,33 @@ def test_other_op_no_wrapping():
)
def test_no_tensor_output_op_no_wrapping(op):
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
    output = op(label)
-    assert type(output) is not datapoints.Label
+    assert type(output) is not proto_datapoints.Label
def test_inplace_op_no_wrapping():
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
    output = label.add_(0)
    assert type(output) is torch.Tensor
-    assert type(label) is datapoints.Label
+    assert type(label) is proto_datapoints.Label
def test_wrap_like():
    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
-    label = datapoints.Label(tensor, categories=["foo", "bar"])
+    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])
    # any operation besides .to() and .clone() will do here
    output = label * 2
-    label_new = datapoints.Label.wrap_like(label, output)
-    assert type(label_new) is datapoints.Label
+    label_new = proto_datapoints.Label.wrap_like(label, output)
+    assert type(label_new) is proto_datapoints.Label
    assert label_new.data_ptr() == output.data_ptr()
    assert label_new.categories is label.categories
...
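
The renamed tests above pin down the wrapping semantics of the prototype Label datapoint, which this commit deliberately leaves out of the beta namespace. A condensed sketch of what they assert, assuming Label still lives under torchvision.prototype.datapoints at this commit:

    import torch
    import torchvision.prototype.datapoints as proto_datapoints

    tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
    label = proto_datapoints.Label(tensor, categories=["foo", "bar"])

    assert label.data_ptr() == tensor.data_ptr()                   # wrapping does not copy
    assert type(label.to(torch.int32)) is proto_datapoints.Label   # .to() keeps the subclass
    assert type(label.clone()) is proto_datapoints.Label           # .clone() keeps the subclass
    assert type(label * 2) is torch.Tensor                         # other ops unwrap to a plain tensor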
@@ -5,8 +5,8 @@ from pathlib import Path
import pytest
import torch
-import torchvision.prototype.transforms.utils
+import torchvision.transforms.v2 as transforms
from builtin_dataset_mocks import DATASET_MOCKS, parametrize_dataset_mocks
from torch.testing._comparison import not_close_error_metas, ObjectPair, TensorLikePair
@@ -19,10 +19,13 @@ from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.dataloader2.graph.utils import traverse_dps
from torchdata.datapipes.iter import ShardingFilter, Shuffler
from torchdata.datapipes.utils import StreamWrapper
+from torchvision import datapoints
from torchvision._utils import sequence_to_str
-from torchvision.prototype import datapoints, datasets, transforms
+from torchvision.prototype import datasets
+from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import EncodedImage
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
+from torchvision.transforms.v2.utils import is_simple_tensor
def assert_samples_equal(*args, msg=None, **kwargs):
@@ -141,9 +144,7 @@ class TestCommon:
        dataset, _ = dataset_mock.load(config)
        sample = next_consume(iter(dataset))
-        simple_tensors = {
-            key for key, value in sample.items() if torchvision.prototype.transforms.utils.is_simple_tensor(value)
-        }
+        simple_tensors = {key for key, value in sample.items() if is_simple_tensor(value)}
        if simple_tensors and not any(
            isinstance(item, (datapoints.Image, datapoints.Video, EncodedImage)) for item in sample.values()
@@ -276,6 +277,6 @@ class TestUSPS:
        assert "label" in sample
        assert isinstance(sample["image"], datapoints.Image)
-        assert isinstance(sample["label"], datapoints.Label)
+        assert isinstance(sample["label"], Label)
        assert sample["image"].shape == (1, 16, 16)
@@ -10,8 +10,11 @@ import numpy as np
import PIL.Image
import pytest
import torch
+import torchvision.prototype.datapoints as proto_datapoints
+import torchvision.prototype.transforms as proto_transforms
+import torchvision.transforms.v2 as transforms
-import torchvision.prototype.transforms.utils
+import torchvision.transforms.v2.utils
from common_utils import cpu_and_gpu
from prototype_common_utils import (
    assert_equal,
@@ -28,11 +31,12 @@ from prototype_common_utils import (
    make_videos,
)
from torch.utils._pytree import tree_flatten, tree_unflatten
+from torchvision import datapoints
from torchvision.ops.boxes import box_iou
-from torchvision.prototype import datapoints, transforms
-from torchvision.prototype.transforms import functional as F
-from torchvision.prototype.transforms.utils import check_type, is_simple_tensor, query_chw
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
+from torchvision.transforms.v2 import functional as F
+from torchvision.transforms.v2._utils import _convert_fill_arg
+from torchvision.transforms.v2.utils import check_type, is_simple_tensor, query_chw
BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -281,8 +285,8 @@ class TestSmoke:
            ],
        )
        for transform in [
-            transforms.RandomMixup(alpha=1.0),
-            transforms.RandomCutmix(alpha=1.0),
+            proto_transforms.RandomMixup(alpha=1.0),
+            proto_transforms.RandomCutmix(alpha=1.0),
        ]
    ]
)
@@ -563,7 +567,7 @@ class TestPad:
    def test__transform(self, padding, fill, padding_mode, mocker):
        transform = transforms.Pad(padding, fill=fill, padding_mode=padding_mode)
-        fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
+        fn = mocker.patch("torchvision.transforms.v2.functional.pad")
        inpt = mocker.MagicMock(spec=datapoints.Image)
        _ = transform(inpt)
@@ -576,7 +580,7 @@ class TestPad:
    def test__transform_image_mask(self, fill, mocker):
        transform = transforms.Pad(1, fill=fill, padding_mode="constant")
-        fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
+        fn = mocker.patch("torchvision.transforms.v2.functional.pad")
        image = datapoints.Image(torch.rand(3, 32, 32))
        mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32)))
        inpt = [image, mask]
@@ -634,7 +638,7 @@ class TestRandomZoomOut:
        transform = transforms.RandomZoomOut(fill=fill, side_range=side_range, p=1)
-        fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
+        fn = mocker.patch("torchvision.transforms.v2.functional.pad")
        # vfdev-5, Feature Request: let's store params as Transform attribute
        # This could be also helpful for users
        # Otherwise, we can mock transform._get_params
@@ -651,7 +655,7 @@ class TestRandomZoomOut:
    def test__transform_image_mask(self, fill, mocker):
        transform = transforms.RandomZoomOut(fill=fill, p=1.0)
-        fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
+        fn = mocker.patch("torchvision.transforms.v2.functional.pad")
        image = datapoints.Image(torch.rand(3, 32, 32))
        mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32)))
        inpt = [image, mask]
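
The mocker.patch targets follow the same rename: a patch has to name the module the transform actually looks the function up in, so every torchvision.prototype.transforms.functional.* target becomes torchvision.transforms.v2.functional.*. A minimal sketch of the pattern, assuming the pytest-mock mocker fixture used throughout this file:

    import torchvision.transforms.v2 as transforms
    from torchvision import datapoints

    def test_pad_dispatches_to_v2_functional(mocker):
        # Patching the old prototype path would no longer intercept anything.
        fn = mocker.patch("torchvision.transforms.v2.functional.pad")
        transform = transforms.Pad(1, fill=0, padding_mode="constant")
        transform(mocker.MagicMock(spec=datapoints.Image))
        fn.assert_called_once()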
@@ -724,7 +728,7 @@ class TestRandomRotation:
        else:
            assert transform.degrees == [float(-degrees), float(degrees)]
-        fn = mocker.patch("torchvision.prototype.transforms.functional.rotate")
+        fn = mocker.patch("torchvision.transforms.v2.functional.rotate")
        inpt = mocker.MagicMock(spec=datapoints.Image)
        # vfdev-5, Feature Request: let's store params as Transform attribute
        # This could be also helpful for users
@@ -859,7 +863,7 @@ class TestRandomAffine:
        else:
            assert transform.degrees == [float(-degrees), float(degrees)]
-        fn = mocker.patch("torchvision.prototype.transforms.functional.affine")
+        fn = mocker.patch("torchvision.transforms.v2.functional.affine")
        inpt = mocker.MagicMock(spec=datapoints.Image)
        inpt.num_channels = 3
        inpt.spatial_size = (24, 32)
@@ -964,8 +968,8 @@ class TestRandomCrop:
            )
        else:
            expected.spatial_size = inpt.spatial_size
-        _ = mocker.patch("torchvision.prototype.transforms.functional.pad", return_value=expected)
-        fn_crop = mocker.patch("torchvision.prototype.transforms.functional.crop")
+        _ = mocker.patch("torchvision.transforms.v2.functional.pad", return_value=expected)
+        fn_crop = mocker.patch("torchvision.transforms.v2.functional.crop")
        # vfdev-5, Feature Request: let's store params as Transform attribute
        # This could be also helpful for users
@@ -1036,7 +1040,7 @@ class TestGaussianBlur:
        else:
            assert transform.sigma == [sigma, sigma]
-        fn = mocker.patch("torchvision.prototype.transforms.functional.gaussian_blur")
+        fn = mocker.patch("torchvision.transforms.v2.functional.gaussian_blur")
        inpt = mocker.MagicMock(spec=datapoints.Image)
        inpt.num_channels = 3
        inpt.spatial_size = (24, 32)
@@ -1068,7 +1072,7 @@ class TestRandomColorOp:
    def test__transform(self, p, transform_cls, func_op_name, kwargs, mocker):
        transform = transform_cls(p=p, **kwargs)
-        fn = mocker.patch(f"torchvision.prototype.transforms.functional.{func_op_name}")
+        fn = mocker.patch(f"torchvision.transforms.v2.functional.{func_op_name}")
        inpt = mocker.MagicMock(spec=datapoints.Image)
        _ = transform(inpt)
        if p > 0.0:
@@ -1104,7 +1108,7 @@ class TestRandomPerspective:
        fill = 12
        transform = transforms.RandomPerspective(distortion_scale, fill=fill, interpolation=interpolation)
-        fn = mocker.patch("torchvision.prototype.transforms.functional.perspective")
+        fn = mocker.patch("torchvision.transforms.v2.functional.perspective")
        inpt = mocker.MagicMock(spec=datapoints.Image)
        inpt.num_channels = 3
        inpt.spatial_size = (24, 32)
@@ -1178,7 +1182,7 @@ class TestElasticTransform:
        else:
            assert transform.sigma == sigma
-        fn = mocker.patch("torchvision.prototype.transforms.functional.elastic")
+        fn = mocker.patch("torchvision.transforms.v2.functional.elastic")
        inpt = mocker.MagicMock(spec=datapoints.Image)
        inpt.num_channels = 3
        inpt.spatial_size = (24, 32)
@@ -1251,13 +1255,13 @@ class TestRandomErasing:
        w_sentinel = mocker.MagicMock()
        v_sentinel = mocker.MagicMock()
        mocker.patch(
-            "torchvision.prototype.transforms._augment.RandomErasing._get_params",
+            "torchvision.transforms.v2._augment.RandomErasing._get_params",
            return_value=dict(i=i_sentinel, j=j_sentinel, h=h_sentinel, w=w_sentinel, v=v_sentinel),
        )
        inpt_sentinel = mocker.MagicMock()
-        mock = mocker.patch("torchvision.prototype.transforms._augment.F.erase")
+        mock = mocker.patch("torchvision.transforms.v2._augment.F.erase")
        output = transform(inpt_sentinel)
        if p:
@@ -1300,7 +1304,7 @@ class TestToImageTensor:
    )
    def test__transform(self, inpt_type, mocker):
        fn = mocker.patch(
-            "torchvision.prototype.transforms.functional.to_image_tensor",
+            "torchvision.transforms.v2.functional.to_image_tensor",
            return_value=torch.rand(1, 3, 8, 8),
        )
@@ -1319,7 +1323,7 @@ class TestToImagePIL:
        [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int],
    )
    def test__transform(self, inpt_type, mocker):
-        fn = mocker.patch("torchvision.prototype.transforms.functional.to_image_pil")
+        fn = mocker.patch("torchvision.transforms.v2.functional.to_image_pil")
        inpt = mocker.MagicMock(spec=inpt_type)
        transform = transforms.ToImagePIL()
@@ -1336,7 +1340,7 @@ class TestToPILImage:
        [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBox, str, int],
    )
    def test__transform(self, inpt_type, mocker):
-        fn = mocker.patch("torchvision.prototype.transforms.functional.to_image_pil")
+        fn = mocker.patch("torchvision.transforms.v2.functional.to_image_pil")
        inpt = mocker.MagicMock(spec=inpt_type)
        transform = transforms.ToPILImage()
@@ -1443,7 +1447,7 @@ class TestRandomIoUCrop:
        transform = transforms.RandomIoUCrop(sampler_options=[2.0])
        image = datapoints.Image(torch.rand(1, 3, 4, 4))
        bboxes = datapoints.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", spatial_size=(4, 4))
-        label = datapoints.Label(torch.tensor([1]))
+        label = proto_datapoints.Label(torch.tensor([1]))
        sample = [image, bboxes, label]
        # Let's mock transform._get_params to control the output:
        transform._get_params = mocker.MagicMock(return_value={})
@@ -1454,7 +1458,7 @@ class TestRandomIoUCrop:
        transform = transforms.RandomIoUCrop()
        with pytest.raises(
            TypeError,
-            match="requires input sample to contain Images or PIL Images, BoundingBoxes and Labels or OneHotLabels",
+            match="requires input sample to contain tensor or PIL images and bounding boxes",
        ):
            transform(torch.tensor(0))
@@ -1463,13 +1467,11 @@ class TestRandomIoUCrop:
        image = datapoints.Image(torch.rand(3, 32, 24))
        bboxes = make_bounding_box(format="XYXY", spatial_size=(32, 24), extra_dims=(6,))
-        label = datapoints.Label(torch.randint(0, 10, size=(6,)))
-        ohe_label = datapoints.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1))
        masks = make_detection_mask((32, 24), num_objects=6)
-        sample = [image, bboxes, label, ohe_label, masks]
+        sample = [image, bboxes, masks]
-        fn = mocker.patch("torchvision.prototype.transforms.functional.crop", side_effect=lambda x, **params: x)
+        fn = mocker.patch("torchvision.transforms.v2.functional.crop", side_effect=lambda x, **params: x)
        is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool)
        params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area)
@@ -1493,17 +1495,7 @@ class TestRandomIoUCrop:
        assert isinstance(output_bboxes, datapoints.BoundingBox)
        assert len(output_bboxes) == expected_within_targets
-        # check labels
-        output_label = output[2]
-        assert isinstance(output_label, datapoints.Label)
-        assert len(output_label) == expected_within_targets
-        torch.testing.assert_close(output_label, label[is_within_crop_area])
-        output_ohe_label = output[3]
-        assert isinstance(output_ohe_label, datapoints.OneHotLabel)
-        torch.testing.assert_close(output_ohe_label, ohe_label[is_within_crop_area])
-        output_masks = output[4]
+        output_masks = output[2]
        assert isinstance(output_masks, datapoints.Mask)
        assert len(output_masks) == expected_within_targets
@@ -1545,12 +1537,12 @@ class TestScaleJitter:
        size_sentinel = mocker.MagicMock()
        mocker.patch(
-            "torchvision.prototype.transforms._geometry.ScaleJitter._get_params", return_value=dict(size=size_sentinel)
+            "torchvision.transforms.v2._geometry.ScaleJitter._get_params", return_value=dict(size=size_sentinel)
        )
        inpt_sentinel = mocker.MagicMock()
-        mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
+        mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
        transform(inpt_sentinel)
        mock.assert_called_once_with(
@@ -1592,13 +1584,13 @@ class TestRandomShortestSize:
        size_sentinel = mocker.MagicMock()
        mocker.patch(
-            "torchvision.prototype.transforms._geometry.RandomShortestSize._get_params",
+            "torchvision.transforms.v2._geometry.RandomShortestSize._get_params",
            return_value=dict(size=size_sentinel),
        )
        inpt_sentinel = mocker.MagicMock()
-        mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
+        mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
        transform(inpt_sentinel)
        mock.assert_called_once_with(
@@ -1613,13 +1605,13 @@ class TestSimpleCopyPaste:
        return mocker.MagicMock(spec=image_type)
    def test__extract_image_targets_assertion(self, mocker):
-        transform = transforms.SimpleCopyPaste()
+        transform = proto_transforms.SimpleCopyPaste()
        flat_sample = [
            # images, batch size = 2
            self.create_fake_image(mocker, datapoints.Image),
            # labels, bboxes, masks
-            mocker.MagicMock(spec=datapoints.Label),
+            mocker.MagicMock(spec=proto_datapoints.Label),
            mocker.MagicMock(spec=datapoints.BoundingBox),
            mocker.MagicMock(spec=datapoints.Mask),
            # labels, bboxes, masks
@@ -1631,9 +1623,9 @@ class TestSimpleCopyPaste:
            transform._extract_image_targets(flat_sample)
    @pytest.mark.parametrize("image_type", [datapoints.Image, PIL.Image.Image, torch.Tensor])
-    @pytest.mark.parametrize("label_type", [datapoints.Label, datapoints.OneHotLabel])
+    @pytest.mark.parametrize("label_type", [proto_datapoints.Label, proto_datapoints.OneHotLabel])
    def test__extract_image_targets(self, image_type, label_type, mocker):
-        transform = transforms.SimpleCopyPaste()
+        transform = proto_transforms.SimpleCopyPaste()
        flat_sample = [
            # images, batch size = 2
@@ -1669,7 +1661,7 @@ class TestSimpleCopyPaste:
            assert isinstance(target[key], type_)
            assert target[key] in flat_sample
-    @pytest.mark.parametrize("label_type", [datapoints.Label, datapoints.OneHotLabel])
+    @pytest.mark.parametrize("label_type", [proto_datapoints.Label, proto_datapoints.OneHotLabel])
    def test__copy_paste(self, label_type):
        image = 2 * torch.ones(3, 32, 32)
        masks = torch.zeros(2, 32, 32)
@@ -1679,7 +1671,7 @@ class TestSimpleCopyPaste:
        blending = True
        resize_interpolation = InterpolationMode.BILINEAR
        antialias = None
-        if label_type == datapoints.OneHotLabel:
+        if label_type == proto_datapoints.OneHotLabel:
            labels = torch.nn.functional.one_hot(labels, num_classes=5)
        target = {
            "boxes": datapoints.BoundingBox(
@@ -1694,7 +1686,7 @@ class TestSimpleCopyPaste:
        paste_masks[0, 13:19, 12:18] = 1
        paste_masks[1, 15:19, 1:8] = 1
        paste_labels = torch.tensor([3, 4])
-        if label_type == datapoints.OneHotLabel:
+        if label_type == proto_datapoints.OneHotLabel:
            paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5)
        paste_target = {
            "boxes": datapoints.BoundingBox(
@@ -1704,7 +1696,7 @@ class TestSimpleCopyPaste:
            "labels": label_type(paste_labels),
        }
-        transform = transforms.SimpleCopyPaste()
+        transform = proto_transforms.SimpleCopyPaste()
        random_selection = torch.tensor([0, 1])
        output_image, output_target = transform._copy_paste(
            image, target, paste_image, paste_target, random_selection, blending, resize_interpolation, antialias
@@ -1716,7 +1708,7 @@ class TestSimpleCopyPaste:
        torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"])
        expected_labels = torch.tensor([1, 2, 3, 4])
-        if label_type == datapoints.OneHotLabel:
+        if label_type == proto_datapoints.OneHotLabel:
            expected_labels = torch.nn.functional.one_hot(expected_labels, num_classes=5)
        torch.testing.assert_close(output_target["labels"], label_type(expected_labels))
@@ -1731,7 +1723,7 @@ class TestFixedSizeCrop:
        batch_shape = (10,)
        spatial_size = (11, 5)
-        transform = transforms.FixedSizeCrop(size=crop_size)
+        transform = proto_transforms.FixedSizeCrop(size=crop_size)
        flat_inputs = [
            make_image(size=spatial_size, color_space="RGB"),
@@ -1759,9 +1751,8 @@ class TestFixedSizeCrop:
        fill_sentinel = 12
        padding_mode_sentinel = mocker.MagicMock()
-        transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel)
+        transform = proto_transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel)
        transform._transformed_types = (mocker.MagicMock,)
-        mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
        mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
        needs_crop, needs_pad = needs
@@ -1810,7 +1801,7 @@ class TestFixedSizeCrop:
        if not needs_crop:
            assert args[0] is inpt_sentinel
            assert args[1] is padding_sentinel
-            fill_sentinel = transforms._utils._convert_fill_arg(fill_sentinel)
+            fill_sentinel = _convert_fill_arg(fill_sentinel)
            assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel)
        else:
            mock_pad.assert_not_called()
@@ -1839,8 +1830,7 @@ class TestFixedSizeCrop:
        masks = make_detection_mask(size=spatial_size, extra_dims=(batch_size,))
        labels = make_label(extra_dims=(batch_size,))
-        transform = transforms.FixedSizeCrop((-1, -1))
+        transform = proto_transforms.FixedSizeCrop((-1, -1))
-        mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
        mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
        output = transform(
@@ -1877,8 +1867,7 @@ class TestFixedSizeCrop:
        )
        mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box")
-        transform = transforms.FixedSizeCrop((-1, -1))
+        transform = proto_transforms.FixedSizeCrop((-1, -1))
-        mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
        mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
        transform(bounding_box)
@@ -1922,10 +1911,10 @@ class TestLinearTransformation:
class TestLabelToOneHot:
    def test__transform(self):
        categories = ["apple", "pear", "pineapple"]
-        labels = datapoints.Label(torch.tensor([0, 1, 2, 1]), categories=categories)
+        labels = proto_datapoints.Label(torch.tensor([0, 1, 2, 1]), categories=categories)
-        transform = transforms.LabelToOneHot()
+        transform = proto_transforms.LabelToOneHot()
        ohe_labels = transform(labels)
-        assert isinstance(ohe_labels, datapoints.OneHotLabel)
+        assert isinstance(ohe_labels, proto_datapoints.OneHotLabel)
        assert ohe_labels.shape == (4, 3)
        assert ohe_labels.categories == labels.categories == categories
@@ -1956,13 +1945,13 @@ class TestRandomResize:
        size_sentinel = mocker.MagicMock()
        mocker.patch(
-            "torchvision.prototype.transforms._geometry.RandomResize._get_params",
+            "torchvision.transforms.v2._geometry.RandomResize._get_params",
            return_value=dict(size=size_sentinel),
        )
        inpt_sentinel = mocker.MagicMock()
-        mock_resize = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
+        mock_resize = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
        transform(inpt_sentinel)
        mock_resize.assert_called_with(
@@ -2048,7 +2037,7 @@ class TestPermuteDimensions:
            int=0,
        )
-        transform = transforms.PermuteDimensions(dims)
+        transform = proto_transforms.PermuteDimensions(dims)
        transformed_sample = transform(sample)
        for key, value in sample.items():
@@ -2056,7 +2045,7 @@ class TestPermuteDimensions:
            transformed_value = transformed_sample[key]
            if check_type(
-                value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video)
+                value, (datapoints.Image, torchvision.transforms.v2.utils.is_simple_tensor, datapoints.Video)
            ):
                if transform.dims.get(value_type) is not None:
                    assert transformed_value.permute(inverse_dims[value_type]).equal(value)
@@ -2067,14 +2056,14 @@ class TestPermuteDimensions:
    @pytest.mark.filterwarnings("error")
    def test_plain_tensor_call(self):
        tensor = torch.empty((2, 3, 4))
-        transform = transforms.PermuteDimensions(dims=(1, 2, 0))
+        transform = proto_transforms.PermuteDimensions(dims=(1, 2, 0))
        assert transform(tensor).shape == (3, 4, 2)
    @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
    def test_plain_tensor_warning(self, other_type):
        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
-            transforms.PermuteDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})
+            proto_transforms.PermuteDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})
class TestTransposeDimensions:
@@ -2094,7 +2083,7 @@ class TestTransposeDimensions:
            int=0,
        )
-        transform = transforms.TransposeDimensions(dims)
+        transform = proto_transforms.TransposeDimensions(dims)
        transformed_sample = transform(sample)
        for key, value in sample.items():
@@ -2103,7 +2092,7 @@ class TestTransposeDimensions:
            transposed_dims = transform.dims.get(value_type)
            if check_type(
-                value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video)
+                value, (datapoints.Image, torchvision.transforms.v2.utils.is_simple_tensor, datapoints.Video)
            ):
                if transposed_dims is not None:
                    assert transformed_value.transpose(*transposed_dims).equal(value)
@@ -2114,14 +2103,14 @@ class TestTransposeDimensions:
    @pytest.mark.filterwarnings("error")
    def test_plain_tensor_call(self):
        tensor = torch.empty((2, 3, 4))
-        transform = transforms.TransposeDimensions(dims=(0, 2))
+        transform = proto_transforms.TransposeDimensions(dims=(0, 2))
        assert transform(tensor).shape == (4, 3, 2)
    @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
    def test_plain_tensor_warning(self, other_type):
        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
-            transforms.TransposeDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})
+            proto_transforms.TransposeDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})
class TestUniformTemporalSubsample:
...
@@ -12,6 +12,8 @@ import PIL.Image
import pytest
import torch
+import torchvision.prototype.transforms as prototype_transforms
+import torchvision.transforms.v2 as v2_transforms
from prototype_common_utils import (
    ArgsKwargs,
    assert_close,
@@ -24,13 +26,13 @@ from prototype_common_utils import (
    make_segmentation_mask,
)
from torch import nn
-from torchvision import transforms as legacy_transforms
+from torchvision import datapoints, transforms as legacy_transforms
from torchvision._utils import sequence_to_str
-from torchvision.prototype import datapoints, transforms as prototype_transforms
-from torchvision.prototype.transforms import functional as prototype_F
-from torchvision.prototype.transforms.functional import to_image_pil
-from torchvision.prototype.transforms.utils import query_spatial_size
from torchvision.transforms import functional as legacy_F
+from torchvision.transforms.v2 import functional as prototype_F
+from torchvision.transforms.v2.functional import to_image_pil
+from torchvision.transforms.v2.utils import query_spatial_size
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)])
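
Each config below pits a promoted v2 transform against its stable counterpart on the same inputs. A standalone sketch of that check for the Normalize entry (deterministic, so the outputs should match; the real harness additionally handles PIL inputs, scripting, and per-config tolerances):

    import torch
    import torch.testing
    import torchvision.transforms.v2 as v2_transforms
    from torchvision import transforms as legacy_transforms

    image = torch.rand(3, 32, 32)
    mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

    v2_out = v2_transforms.Normalize(mean=mean, std=std)(image)
    legacy_out = legacy_transforms.Normalize(mean=mean, std=std)(image)
    torch.testing.assert_close(v2_out, legacy_out)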
@@ -71,7 +73,7 @@ LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] *
CONSISTENCY_CONFIGS = [
    ConsistencyConfig(
-        prototype_transforms.Normalize,
+        v2_transforms.Normalize,
        legacy_transforms.Normalize,
        [
            ArgsKwargs(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
@@ -80,14 +82,14 @@ CONSISTENCY_CONFIGS = [
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.float]),
    ),
    ConsistencyConfig(
-        prototype_transforms.Resize,
+        v2_transforms.Resize,
        legacy_transforms.Resize,
        [
            NotScriptableArgsKwargs(32),
            ArgsKwargs([32]),
            ArgsKwargs((32, 29)),
-            ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
-            ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
+            ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST),
+            ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC),
            ArgsKwargs((30, 27), interpolation=PIL.Image.NEAREST),
            ArgsKwargs((35, 29), interpolation=PIL.Image.BILINEAR),
            ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC),
@@ -100,7 +102,7 @@ CONSISTENCY_CONFIGS = [
        ],
    ),
    ConsistencyConfig(
-        prototype_transforms.CenterCrop,
+        v2_transforms.CenterCrop,
        legacy_transforms.CenterCrop,
        [
            ArgsKwargs(18),
@@ -108,7 +110,7 @@ CONSISTENCY_CONFIGS = [
        ],
    ),
    ConsistencyConfig(
-        prototype_transforms.FiveCrop,
+        v2_transforms.FiveCrop,
        legacy_transforms.FiveCrop,
        [
            ArgsKwargs(18),
@@ -117,7 +119,7 @@ CONSISTENCY_CONFIGS = [
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
    ),
    ConsistencyConfig(
-        prototype_transforms.TenCrop,
+        v2_transforms.TenCrop,
        legacy_transforms.TenCrop,
        [
            ArgsKwargs(18),
@@ -127,7 +129,7 @@ CONSISTENCY_CONFIGS = [
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(20, 19)]),
    ),
    ConsistencyConfig(
-        prototype_transforms.Pad,
+        v2_transforms.Pad,
        legacy_transforms.Pad,
        [
            NotScriptableArgsKwargs(3),
@@ -143,7 +145,7 @@ CONSISTENCY_CONFIGS = [
    ),
    *[
        ConsistencyConfig(
-            prototype_transforms.LinearTransformation,
+            v2_transforms.LinearTransformation,
            legacy_transforms.LinearTransformation,
            [
                ArgsKwargs(LINEAR_TRANSFORMATION_MATRIX.to(matrix_dtype), LINEAR_TRANSFORMATION_MEAN.to(matrix_dtype)),
@@ -164,7 +166,7 @@ CONSISTENCY_CONFIGS = [
        ]
    ],
    ConsistencyConfig(
-        prototype_transforms.Grayscale,
+        v2_transforms.Grayscale,
        legacy_transforms.Grayscale,
        [
            ArgsKwargs(num_output_channels=1),
@@ -175,7 +177,7 @@ CONSISTENCY_CONFIGS = [
        closeness_kwargs=dict(rtol=None, atol=None),
    ),
    ConsistencyConfig(
-        prototype_transforms.ConvertDtype,
+        v2_transforms.ConvertDtype,
        legacy_transforms.ConvertImageDtype,
        [
            ArgsKwargs(torch.float16),
@@ -189,7 +191,7 @@ CONSISTENCY_CONFIGS = [
        closeness_kwargs=dict(rtol=None, atol=None),
    ),
    ConsistencyConfig(
-        prototype_transforms.ToPILImage,
+        v2_transforms.ToPILImage,
        legacy_transforms.ToPILImage,
        [NotScriptableArgsKwargs()],
        make_images_kwargs=dict(
@@ -204,7 +206,7 @@ CONSISTENCY_CONFIGS = [
        supports_pil=False,
    ),
    ConsistencyConfig(
-        prototype_transforms.Lambda,
+        v2_transforms.Lambda,
        legacy_transforms.Lambda,
        [
            NotScriptableArgsKwargs(lambda image: image / 2),
@@ -214,7 +216,7 @@ CONSISTENCY_CONFIGS = [
        supports_pil=False,
    ),
    ConsistencyConfig(
-        prototype_transforms.RandomHorizontalFlip,
+        v2_transforms.RandomHorizontalFlip,
        legacy_transforms.RandomHorizontalFlip,
        [
            ArgsKwargs(p=0),
@@ -222,7 +224,7 @@ CONSISTENCY_CONFIGS = [
        ],
    ),
    ConsistencyConfig(
-        prototype_transforms.RandomVerticalFlip,
+        v2_transforms.RandomVerticalFlip,
        legacy_transforms.RandomVerticalFlip,
        [
            ArgsKwargs(p=0),
@@ -230,7 +232,7 @@ CONSISTENCY_CONFIGS = [
        ],
    ),
    ConsistencyConfig(
-        prototype_transforms.RandomEqualize,
+        v2_transforms.RandomEqualize,
        legacy_transforms.RandomEqualize,
        [
            ArgsKwargs(p=0),
@@ -239,7 +241,7 @@ CONSISTENCY_CONFIGS = [
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
    ),
    ConsistencyConfig(
-        prototype_transforms.RandomInvert,
+        v2_transforms.RandomInvert,
        legacy_transforms.RandomInvert,
        [
            ArgsKwargs(p=0),
@@ -247,7 +249,7 @@ CONSISTENCY_CONFIGS = [
        ],
    ),
    ConsistencyConfig(
-        prototype_transforms.RandomPosterize,
+        v2_transforms.RandomPosterize,
        legacy_transforms.RandomPosterize,
        [
            ArgsKwargs(p=0, bits=5),
@@ -257,7 +259,7 @@ CONSISTENCY_CONFIGS = [
        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, dtypes=[torch.uint8]),
    ),
    ConsistencyConfig(
-        prototype_transforms.RandomSolarize,
+        v2_transforms.RandomSolarize,
        legacy_transforms.RandomSolarize,
        [
            ArgsKwargs(p=0, threshold=0.5),
@@ -267,7 +269,7 @@ CONSISTENCY_CONFIGS = [
    ),
    *[
        ConsistencyConfig(
-            prototype_transforms.RandomAutocontrast,
+            v2_transforms.RandomAutocontrast,
            legacy_transforms.RandomAutocontrast,
            [
                ArgsKwargs(p=0),
@@ -279,7 +281,7 @@ CONSISTENCY_CONFIGS = [
        for dt, ckw in [(torch.uint8, dict(atol=1, rtol=0)), (torch.float32, dict(rtol=None, atol=None))]
    ],
    ConsistencyConfig(
-        prototype_transforms.RandomAdjustSharpness,
+        v2_transforms.RandomAdjustSharpness,
        legacy_transforms.RandomAdjustSharpness,
        [
            ArgsKwargs(p=0, sharpness_factor=0.5),
@@ -289,7 +291,7 @@ CONSISTENCY_CONFIGS = [
        closeness_kwargs={"atol": 1e-6, "rtol": 1e-6},
    ),
    ConsistencyConfig(
-        prototype_transforms.RandomGrayscale,
+        v2_transforms.RandomGrayscale,
        legacy_transforms.RandomGrayscale,
        [
            ArgsKwargs(p=0),
@@ -300,14 +302,14 @@ CONSISTENCY_CONFIGS = [
closeness_kwargs=dict(rtol=None, atol=None), closeness_kwargs=dict(rtol=None, atol=None),
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.RandomResizedCrop, v2_transforms.RandomResizedCrop,
legacy_transforms.RandomResizedCrop, legacy_transforms.RandomResizedCrop,
[ [
ArgsKwargs(16), ArgsKwargs(16),
ArgsKwargs(17, scale=(0.3, 0.7)), ArgsKwargs(17, scale=(0.3, 0.7)),
ArgsKwargs(25, ratio=(0.5, 1.5)), ArgsKwargs(25, ratio=(0.5, 1.5)),
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST), ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC), ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC),
ArgsKwargs((31, 28), interpolation=PIL.Image.NEAREST), ArgsKwargs((31, 28), interpolation=PIL.Image.NEAREST),
ArgsKwargs((33, 26), interpolation=PIL.Image.BICUBIC), ArgsKwargs((33, 26), interpolation=PIL.Image.BICUBIC),
ArgsKwargs((29, 32), antialias=False), ArgsKwargs((29, 32), antialias=False),
...@@ -315,7 +317,7 @@ CONSISTENCY_CONFIGS = [ ...@@ -315,7 +317,7 @@ CONSISTENCY_CONFIGS = [
], ],
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.RandomErasing, v2_transforms.RandomErasing,
legacy_transforms.RandomErasing, legacy_transforms.RandomErasing,
[ [
ArgsKwargs(p=0), ArgsKwargs(p=0),
...@@ -329,7 +331,7 @@ CONSISTENCY_CONFIGS = [ ...@@ -329,7 +331,7 @@ CONSISTENCY_CONFIGS = [
supports_pil=False, supports_pil=False,
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.ColorJitter, v2_transforms.ColorJitter,
legacy_transforms.ColorJitter, legacy_transforms.ColorJitter,
[ [
ArgsKwargs(), ArgsKwargs(),
...@@ -347,7 +349,7 @@ CONSISTENCY_CONFIGS = [ ...@@ -347,7 +349,7 @@ CONSISTENCY_CONFIGS = [
), ),
*[ *[
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.ElasticTransform, v2_transforms.ElasticTransform,
legacy_transforms.ElasticTransform, legacy_transforms.ElasticTransform,
[ [
ArgsKwargs(), ArgsKwargs(),
...@@ -355,8 +357,8 @@ CONSISTENCY_CONFIGS = [ ...@@ -355,8 +357,8 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(alpha=(15.3, 27.2)), ArgsKwargs(alpha=(15.3, 27.2)),
ArgsKwargs(sigma=3.0), ArgsKwargs(sigma=3.0),
ArgsKwargs(sigma=(2.5, 3.9)), ArgsKwargs(sigma=(2.5, 3.9)),
ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.NEAREST), ArgsKwargs(interpolation=v2_transforms.InterpolationMode.NEAREST),
ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.BICUBIC), ArgsKwargs(interpolation=v2_transforms.InterpolationMode.BICUBIC),
ArgsKwargs(interpolation=PIL.Image.NEAREST), ArgsKwargs(interpolation=PIL.Image.NEAREST),
ArgsKwargs(interpolation=PIL.Image.BICUBIC), ArgsKwargs(interpolation=PIL.Image.BICUBIC),
ArgsKwargs(fill=1), ArgsKwargs(fill=1),
...@@ -370,7 +372,7 @@ CONSISTENCY_CONFIGS = [ ...@@ -370,7 +372,7 @@ CONSISTENCY_CONFIGS = [
for dt, ckw in [(torch.uint8, {"rtol": 1e-1, "atol": 1}), (torch.float32, {"rtol": 1e-2, "atol": 1e-3})] for dt, ckw in [(torch.uint8, {"rtol": 1e-1, "atol": 1}), (torch.float32, {"rtol": 1e-2, "atol": 1e-3})]
], ],
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.GaussianBlur, v2_transforms.GaussianBlur,
legacy_transforms.GaussianBlur, legacy_transforms.GaussianBlur,
[ [
ArgsKwargs(kernel_size=3), ArgsKwargs(kernel_size=3),
...@@ -381,7 +383,7 @@ CONSISTENCY_CONFIGS = [ ...@@ -381,7 +383,7 @@ CONSISTENCY_CONFIGS = [
closeness_kwargs={"rtol": 1e-5, "atol": 1e-5}, closeness_kwargs={"rtol": 1e-5, "atol": 1e-5},
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.RandomAffine, v2_transforms.RandomAffine,
legacy_transforms.RandomAffine, legacy_transforms.RandomAffine,
[ [
ArgsKwargs(degrees=30.0), ArgsKwargs(degrees=30.0),
...@@ -392,7 +394,7 @@ CONSISTENCY_CONFIGS = [ ...@@ -392,7 +394,7 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(degrees=0.0, shear=(8, 17)), ArgsKwargs(degrees=0.0, shear=(8, 17)),
ArgsKwargs(degrees=0.0, shear=(4, 5, 4, 13)), ArgsKwargs(degrees=0.0, shear=(4, 5, 4, 13)),
ArgsKwargs(degrees=(-20.0, 10.0), translate=(0.4, 0.6), scale=(0.3, 0.8), shear=(4, 5, 4, 13)), ArgsKwargs(degrees=(-20.0, 10.0), translate=(0.4, 0.6), scale=(0.3, 0.8), shear=(4, 5, 4, 13)),
ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.NEAREST), ArgsKwargs(degrees=30.0, interpolation=v2_transforms.InterpolationMode.NEAREST),
ArgsKwargs(degrees=30.0, interpolation=PIL.Image.NEAREST), ArgsKwargs(degrees=30.0, interpolation=PIL.Image.NEAREST),
ArgsKwargs(degrees=30.0, fill=1), ArgsKwargs(degrees=30.0, fill=1),
ArgsKwargs(degrees=30.0, fill=(2, 3, 4)), ArgsKwargs(degrees=30.0, fill=(2, 3, 4)),
...@@ -401,7 +403,7 @@ CONSISTENCY_CONFIGS = [ ...@@ -401,7 +403,7 @@ CONSISTENCY_CONFIGS = [
removed_params=["fillcolor", "resample"], removed_params=["fillcolor", "resample"],
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.RandomCrop, v2_transforms.RandomCrop,
legacy_transforms.RandomCrop, legacy_transforms.RandomCrop,
[ [
ArgsKwargs(12), ArgsKwargs(12),
...@@ -421,13 +423,13 @@ CONSISTENCY_CONFIGS = [ ...@@ -421,13 +423,13 @@ CONSISTENCY_CONFIGS = [
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(26, 26), (18, 33), (29, 22)]), make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, sizes=[(26, 26), (18, 33), (29, 22)]),
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.RandomPerspective, v2_transforms.RandomPerspective,
legacy_transforms.RandomPerspective, legacy_transforms.RandomPerspective,
[ [
ArgsKwargs(p=0), ArgsKwargs(p=0),
ArgsKwargs(p=1), ArgsKwargs(p=1),
ArgsKwargs(p=1, distortion_scale=0.3), ArgsKwargs(p=1, distortion_scale=0.3),
ArgsKwargs(p=1, distortion_scale=0.2, interpolation=prototype_transforms.InterpolationMode.NEAREST), ArgsKwargs(p=1, distortion_scale=0.2, interpolation=v2_transforms.InterpolationMode.NEAREST),
ArgsKwargs(p=1, distortion_scale=0.2, interpolation=PIL.Image.NEAREST), ArgsKwargs(p=1, distortion_scale=0.2, interpolation=PIL.Image.NEAREST),
ArgsKwargs(p=1, distortion_scale=0.1, fill=1), ArgsKwargs(p=1, distortion_scale=0.1, fill=1),
ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)), ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)),
...@@ -435,12 +437,12 @@ CONSISTENCY_CONFIGS = [ ...@@ -435,12 +437,12 @@ CONSISTENCY_CONFIGS = [
closeness_kwargs={"atol": None, "rtol": None}, closeness_kwargs={"atol": None, "rtol": None},
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.RandomRotation, v2_transforms.RandomRotation,
legacy_transforms.RandomRotation, legacy_transforms.RandomRotation,
[ [
ArgsKwargs(degrees=30.0), ArgsKwargs(degrees=30.0),
ArgsKwargs(degrees=(-20.0, 10.0)), ArgsKwargs(degrees=(-20.0, 10.0)),
ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.BILINEAR), ArgsKwargs(degrees=30.0, interpolation=v2_transforms.InterpolationMode.BILINEAR),
ArgsKwargs(degrees=30.0, interpolation=PIL.Image.BILINEAR), ArgsKwargs(degrees=30.0, interpolation=PIL.Image.BILINEAR),
ArgsKwargs(degrees=30.0, expand=True), ArgsKwargs(degrees=30.0, expand=True),
ArgsKwargs(degrees=30.0, center=(0, 0)), ArgsKwargs(degrees=30.0, center=(0, 0)),
...@@ -450,43 +452,43 @@ CONSISTENCY_CONFIGS = [ ...@@ -450,43 +452,43 @@ CONSISTENCY_CONFIGS = [
removed_params=["resample"], removed_params=["resample"],
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.PILToTensor, v2_transforms.PILToTensor,
legacy_transforms.PILToTensor, legacy_transforms.PILToTensor,
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.ToTensor, v2_transforms.ToTensor,
legacy_transforms.ToTensor, legacy_transforms.ToTensor,
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.Compose, v2_transforms.Compose,
legacy_transforms.Compose, legacy_transforms.Compose,
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.RandomApply, v2_transforms.RandomApply,
legacy_transforms.RandomApply, legacy_transforms.RandomApply,
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.RandomChoice, v2_transforms.RandomChoice,
legacy_transforms.RandomChoice, legacy_transforms.RandomChoice,
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.RandomOrder, v2_transforms.RandomOrder,
legacy_transforms.RandomOrder, legacy_transforms.RandomOrder,
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.AugMix, v2_transforms.AugMix,
legacy_transforms.AugMix, legacy_transforms.AugMix,
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.AutoAugment, v2_transforms.AutoAugment,
legacy_transforms.AutoAugment, legacy_transforms.AutoAugment,
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.RandAugment, v2_transforms.RandAugment,
legacy_transforms.RandAugment, legacy_transforms.RandAugment,
), ),
ConsistencyConfig( ConsistencyConfig(
prototype_transforms.TrivialAugmentWide, v2_transforms.TrivialAugmentWide,
legacy_transforms.TrivialAugmentWide, legacy_transforms.TrivialAugmentWide,
), ),
] ]
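Each entry above pairs a v2 transform class with its v1 counterpart and the argument sets used to exercise both. The gist of the consistency check, shown here as a stripped-down sketch rather than the actual test helper, is to build both transforms with the same arguments, run them on the same input under the same seed, and compare the outputs within the per-config tolerances:

import torch
import torch.testing
from torchvision import transforms as legacy_transforms
from torchvision.transforms import v2 as v2_transforms

def check_consistency_sketch(v2_cls, legacy_cls, *args, **kwargs):
    # Illustrative only: the real test also covers PIL inputs, datapoints,
    # scripted transforms, and the closeness_kwargs of each config.
    image = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
    torch.manual_seed(0)
    expected = legacy_cls(*args, **kwargs)(image)
    torch.manual_seed(0)
    actual = v2_cls(*args, **kwargs)(image)
    torch.testing.assert_close(actual, expected, rtol=0, atol=1)

check_consistency_sketch(v2_transforms.CenterCrop, legacy_transforms.CenterCrop, 18)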
...@@ -680,19 +682,19 @@ get_params_parametrization = pytest.mark.parametrize( ...@@ -680,19 +682,19 @@ get_params_parametrization = pytest.mark.parametrize(
id=transform_cls.__name__, id=transform_cls.__name__,
) )
for transform_cls, get_params_args_kwargs in [ for transform_cls, get_params_args_kwargs in [
(prototype_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])), (v2_transforms.RandomResizedCrop, ArgsKwargs(make_image(), scale=[0.3, 0.7], ratio=[0.5, 1.5])),
(prototype_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))), (v2_transforms.RandomErasing, ArgsKwargs(make_image(), scale=(0.3, 0.7), ratio=(0.5, 1.5))),
(prototype_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)), (v2_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
(prototype_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])), (v2_transforms.ElasticTransform, ArgsKwargs(alpha=[15.3, 27.2], sigma=[2.5, 3.9], size=[17, 31])),
(prototype_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)), (v2_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)),
( (
prototype_transforms.RandomAffine, v2_transforms.RandomAffine,
ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]), ArgsKwargs(degrees=[-20.0, 10.0], translate=None, scale_ranges=None, shears=None, img_size=[15, 29]),
), ),
(prototype_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))), (v2_transforms.RandomCrop, ArgsKwargs(make_image(size=(61, 47)), output_size=(19, 25))),
(prototype_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)), (v2_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)),
(prototype_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])), (v2_transforms.RandomRotation, ArgsKwargs(degrees=[-20.0, 10.0])),
(prototype_transforms.AutoAugment, ArgsKwargs(5)), (v2_transforms.AutoAugment, ArgsKwargs(5)),
] ]
], ],
) )
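The parametrization above feeds a check that the `get_params` staticmethods of the v2 classes behave like their v1 counterparts. A rough sketch of the idea, assuming both classes expose `get_params` with the signature used in the ArgsKwargs above (the seed handling here is illustrative):

import torch
from torchvision import transforms as legacy_transforms
from torchvision.transforms import v2 as v2_transforms

# With the same RNG state, both implementations should draw the same rotation angle.
torch.manual_seed(0)
legacy_angle = legacy_transforms.RandomRotation.get_params(degrees=[-20.0, 10.0])
torch.manual_seed(0)
v2_angle = v2_transforms.RandomRotation.get_params(degrees=[-20.0, 10.0])
assert legacy_angle == v2_angle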
...@@ -767,10 +769,10 @@ class TestContainerTransforms: ...@@ -767,10 +769,10 @@ class TestContainerTransforms:
""" """
def test_compose(self): def test_compose(self):
prototype_transform = prototype_transforms.Compose( prototype_transform = v2_transforms.Compose(
[ [
prototype_transforms.Resize(256), v2_transforms.Resize(256),
prototype_transforms.CenterCrop(224), v2_transforms.CenterCrop(224),
] ]
) )
legacy_transform = legacy_transforms.Compose( legacy_transform = legacy_transforms.Compose(
...@@ -785,11 +787,11 @@ class TestContainerTransforms: ...@@ -785,11 +787,11 @@ class TestContainerTransforms:
@pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1]) @pytest.mark.parametrize("p", [0, 0.1, 0.5, 0.9, 1])
@pytest.mark.parametrize("sequence_type", [list, nn.ModuleList]) @pytest.mark.parametrize("sequence_type", [list, nn.ModuleList])
def test_random_apply(self, p, sequence_type): def test_random_apply(self, p, sequence_type):
prototype_transform = prototype_transforms.RandomApply( prototype_transform = v2_transforms.RandomApply(
sequence_type( sequence_type(
[ [
prototype_transforms.Resize(256), v2_transforms.Resize(256),
prototype_transforms.CenterCrop(224), v2_transforms.CenterCrop(224),
] ]
), ),
p=p, p=p,
...@@ -814,9 +816,9 @@ class TestContainerTransforms: ...@@ -814,9 +816,9 @@ class TestContainerTransforms:
# We can't test other values for `p` since the random parameter generation is different # We can't test other values for `p` since the random parameter generation is different
@pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)]) @pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)])
def test_random_choice(self, probabilities): def test_random_choice(self, probabilities):
prototype_transform = prototype_transforms.RandomChoice( prototype_transform = v2_transforms.RandomChoice(
[ [
prototype_transforms.Resize(256), v2_transforms.Resize(256),
legacy_transforms.CenterCrop(224), legacy_transforms.CenterCrop(224),
], ],
probabilities=probabilities, probabilities=probabilities,
...@@ -834,7 +836,7 @@ class TestContainerTransforms: ...@@ -834,7 +836,7 @@ class TestContainerTransforms:
class TestToTensorTransforms: class TestToTensorTransforms:
def test_pil_to_tensor(self): def test_pil_to_tensor(self):
prototype_transform = prototype_transforms.PILToTensor() prototype_transform = v2_transforms.PILToTensor()
legacy_transform = legacy_transforms.PILToTensor() legacy_transform = legacy_transforms.PILToTensor()
for image in make_images(extra_dims=[()]): for image in make_images(extra_dims=[()]):
...@@ -844,7 +846,7 @@ class TestToTensorTransforms: ...@@ -844,7 +846,7 @@ class TestToTensorTransforms:
def test_to_tensor(self): def test_to_tensor(self):
with pytest.warns(UserWarning, match=re.escape("The transform `ToTensor()` is deprecated")): with pytest.warns(UserWarning, match=re.escape("The transform `ToTensor()` is deprecated")):
prototype_transform = prototype_transforms.ToTensor() prototype_transform = v2_transforms.ToTensor()
legacy_transform = legacy_transforms.ToTensor() legacy_transform = legacy_transforms.ToTensor()
for image in make_images(extra_dims=[()]): for image in make_images(extra_dims=[()]):
...@@ -867,14 +869,14 @@ class TestAATransforms: ...@@ -867,14 +869,14 @@ class TestAATransforms:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"interpolation", "interpolation",
[ [
prototype_transforms.InterpolationMode.NEAREST, v2_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR, v2_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST, PIL.Image.NEAREST,
], ],
) )
def test_randaug(self, inpt, interpolation, mocker): def test_randaug(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1) t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
t = prototype_transforms.RandAugment(interpolation=interpolation, num_ops=1) t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1)
le = len(t._AUGMENTATION_SPACE) le = len(t._AUGMENTATION_SPACE)
keys = list(t._AUGMENTATION_SPACE.keys()) keys = list(t._AUGMENTATION_SPACE.keys())
...@@ -909,14 +911,14 @@ class TestAATransforms: ...@@ -909,14 +911,14 @@ class TestAATransforms:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"interpolation", "interpolation",
[ [
prototype_transforms.InterpolationMode.NEAREST, v2_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR, v2_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST, PIL.Image.NEAREST,
], ],
) )
def test_trivial_aug(self, inpt, interpolation, mocker): def test_trivial_aug(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation) t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
t = prototype_transforms.TrivialAugmentWide(interpolation=interpolation) t = v2_transforms.TrivialAugmentWide(interpolation=interpolation)
le = len(t._AUGMENTATION_SPACE) le = len(t._AUGMENTATION_SPACE)
keys = list(t._AUGMENTATION_SPACE.keys()) keys = list(t._AUGMENTATION_SPACE.keys())
...@@ -961,15 +963,15 @@ class TestAATransforms: ...@@ -961,15 +963,15 @@ class TestAATransforms:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"interpolation", "interpolation",
[ [
prototype_transforms.InterpolationMode.NEAREST, v2_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR, v2_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST, PIL.Image.NEAREST,
], ],
) )
def test_augmix(self, inpt, interpolation, mocker): def test_augmix(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1) t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
t_ref._sample_dirichlet = lambda t: t.softmax(dim=-1) t_ref._sample_dirichlet = lambda t: t.softmax(dim=-1)
t = prototype_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1) t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
t._sample_dirichlet = lambda t: t.softmax(dim=-1) t._sample_dirichlet = lambda t: t.softmax(dim=-1)
le = len(t._AUGMENTATION_SPACE) le = len(t._AUGMENTATION_SPACE)
...@@ -1014,15 +1016,15 @@ class TestAATransforms: ...@@ -1014,15 +1016,15 @@ class TestAATransforms:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"interpolation", "interpolation",
[ [
prototype_transforms.InterpolationMode.NEAREST, v2_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR, v2_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST, PIL.Image.NEAREST,
], ],
) )
def test_aa(self, inpt, interpolation): def test_aa(self, inpt, interpolation):
aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet") aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation) t_ref = legacy_transforms.AutoAugment(aa_policy, interpolation=interpolation)
t = prototype_transforms.AutoAugment(aa_policy, interpolation=interpolation) t = v2_transforms.AutoAugment(aa_policy, interpolation=interpolation)
torch.manual_seed(12) torch.manual_seed(12)
expected_output = t_ref(inpt) expected_output = t_ref(inpt)
...@@ -1087,10 +1089,16 @@ class TestRefDetTransforms: ...@@ -1087,10 +1089,16 @@ class TestRefDetTransforms:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"t_ref, t, data_kwargs", "t_ref, t, data_kwargs",
[ [
(det_transforms.RandomHorizontalFlip(p=1.0), prototype_transforms.RandomHorizontalFlip(p=1.0), {}), (det_transforms.RandomHorizontalFlip(p=1.0), v2_transforms.RandomHorizontalFlip(p=1.0), {}),
(det_transforms.RandomIoUCrop(), prototype_transforms.RandomIoUCrop(), {"with_mask": False}), # FIXME: make
(det_transforms.RandomZoomOut(), prototype_transforms.RandomZoomOut(), {"with_mask": False}), # v2_transforms.Compose([
(det_transforms.ScaleJitter((1024, 1024)), prototype_transforms.ScaleJitter((1024, 1024)), {}), # v2_transforms.RandomIoUCrop(),
# v2_transforms.SanitizeBoundingBoxes()
# ])
# work
# (det_transforms.RandomIoUCrop(), v2_transforms.RandomIoUCrop(), {"with_mask": False}),
(det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}),
(det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024)), {}),
( (
det_transforms.FixedSizeCrop((1024, 1024), fill=0), det_transforms.FixedSizeCrop((1024, 1024), fill=0),
prototype_transforms.FixedSizeCrop((1024, 1024), fill=0), prototype_transforms.FixedSizeCrop((1024, 1024), fill=0),
...@@ -1100,7 +1108,7 @@ class TestRefDetTransforms: ...@@ -1100,7 +1108,7 @@ class TestRefDetTransforms:
det_transforms.RandomShortestSize( det_transforms.RandomShortestSize(
min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333 min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
), ),
prototype_transforms.RandomShortestSize( v2_transforms.RandomShortestSize(
min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333 min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
), ),
{}, {},
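The FIXME above concerns chaining the v2 IoU crop with bounding box sanitization, since the crop can leave boxes empty or out of bounds. The composition the comment aims for would look roughly like this (a sketch of the intended pipeline, not a currently passing configuration):

from torchvision.transforms import v2 as v2_transforms

# Crop first, then drop the bounding boxes (and matching labels) that the crop invalidated.
t = v2_transforms.Compose(
    [
        v2_transforms.RandomIoUCrop(),
        v2_transforms.SanitizeBoundingBoxes(),
    ]
)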
...@@ -1127,11 +1135,11 @@ seg_transforms = import_transforms_from_references("segmentation") ...@@ -1127,11 +1135,11 @@ seg_transforms = import_transforms_from_references("segmentation")
# 1. transforms.RandomCrop uses a different scheme to pad images and masks of insufficient size than its name # 1. transforms.RandomCrop uses a different scheme to pad images and masks of insufficient size than its name
# counterpart in the detection references. Thus, we cannot use it with `pad_if_needed=True` # counterpart in the detection references. Thus, we cannot use it with `pad_if_needed=True`
# 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size. # 2. transforms.Pad only supports a fixed padding, but the segmentation datasets don't have a fixed image size.
class PadIfSmaller(prototype_transforms.Transform): class PadIfSmaller(v2_transforms.Transform):
def __init__(self, size, fill=0): def __init__(self, size, fill=0):
super().__init__() super().__init__()
self.size = size self.size = size
self.fill = prototype_transforms._geometry._setup_fill_arg(fill) self.fill = v2_transforms._geometry._setup_fill_arg(fill)
def _get_params(self, sample): def _get_params(self, sample):
height, width = query_spatial_size(sample) height, width = query_spatial_size(sample)
...@@ -1193,27 +1201,27 @@ class TestRefSegTransforms: ...@@ -1193,27 +1201,27 @@ class TestRefSegTransforms:
[ [
( (
seg_transforms.RandomHorizontalFlip(flip_prob=1.0), seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
prototype_transforms.RandomHorizontalFlip(p=1.0), v2_transforms.RandomHorizontalFlip(p=1.0),
dict(), dict(),
), ),
( (
seg_transforms.RandomHorizontalFlip(flip_prob=0.0), seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
prototype_transforms.RandomHorizontalFlip(p=0.0), v2_transforms.RandomHorizontalFlip(p=0.0),
dict(), dict(),
), ),
( (
seg_transforms.RandomCrop(size=480), seg_transforms.RandomCrop(size=480),
prototype_transforms.Compose( v2_transforms.Compose(
[ [
PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {datapoints.Mask: 255})), PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {datapoints.Mask: 255})),
prototype_transforms.RandomCrop(size=480), v2_transforms.RandomCrop(size=480),
] ]
), ),
dict(), dict(),
), ),
( (
seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
prototype_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), v2_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
dict(supports_pil=False, image_dtype=torch.float), dict(supports_pil=False, image_dtype=torch.float),
), ),
], ],
...@@ -1222,7 +1230,7 @@ class TestRefSegTransforms: ...@@ -1222,7 +1230,7 @@ class TestRefSegTransforms:
self.check(t, t_ref, data_kwargs) self.check(t, t_ref, data_kwargs)
def check_resize(self, mocker, t_ref, t): def check_resize(self, mocker, t_ref, t):
mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize") mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
mock_ref = mocker.patch("torchvision.transforms.functional.resize") mock_ref = mocker.patch("torchvision.transforms.functional.resize")
for dp, dp_ref in self.make_datapoints(): for dp, dp_ref in self.make_datapoints():
...@@ -1263,9 +1271,9 @@ class TestRefSegTransforms: ...@@ -1263,9 +1271,9 @@ class TestRefSegTransforms:
# We are patching torch.randint -> random.randint here, because we can't patch the modules that are not imported # We are patching torch.randint -> random.randint here, because we can't patch the modules that are not imported
# normally # normally
t = prototype_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True) t = v2_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True)
mocker.patch( mocker.patch(
"torchvision.prototype.transforms._geometry.torch.randint", "torchvision.transforms.v2._geometry.torch.randint",
new=patched_randint, new=patched_randint,
) )
...@@ -1277,7 +1285,7 @@ class TestRefSegTransforms: ...@@ -1277,7 +1285,7 @@ class TestRefSegTransforms:
torch.manual_seed(0) torch.manual_seed(0)
base_size = 520 base_size = 520
t = prototype_transforms.Resize(size=base_size, antialias=True) t = v2_transforms.Resize(size=base_size, antialias=True)
t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size) t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size)
......
...@@ -11,7 +11,6 @@ import pytest ...@@ -11,7 +11,6 @@ import pytest
import torch import torch
import torchvision.prototype.transforms.utils
from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed
from prototype_common_utils import ( from prototype_common_utils import (
assert_close, assert_close,
...@@ -22,11 +21,12 @@ from prototype_common_utils import ( ...@@ -22,11 +21,12 @@ from prototype_common_utils import (
from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
from prototype_transforms_kernel_infos import KERNEL_INFOS from prototype_transforms_kernel_infos import KERNEL_INFOS
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from torchvision.prototype import datapoints from torchvision import datapoints
from torchvision.prototype.transforms import functional as F
from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding
from torchvision.prototype.transforms.functional._meta import clamp_bounding_box, convert_format_bounding_box
from torchvision.transforms.functional import _get_perspective_coeffs from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
from torchvision.transforms.v2.functional._meta import clamp_bounding_box, convert_format_bounding_box
from torchvision.transforms.v2.utils import is_simple_tensor
KERNEL_INFOS_MAP = {info.kernel: info for info in KERNEL_INFOS} KERNEL_INFOS_MAP = {info.kernel: info for info in KERNEL_INFOS}
...@@ -168,11 +168,7 @@ class TestKernels: ...@@ -168,11 +168,7 @@ class TestKernels:
def test_batched_vs_single(self, test_id, info, args_kwargs, device): def test_batched_vs_single(self, test_id, info, args_kwargs, device):
(batched_input, *other_args), kwargs = args_kwargs.load(device) (batched_input, *other_args), kwargs = args_kwargs.load(device)
datapoint_type = ( datapoint_type = datapoints.Image if is_simple_tensor(batched_input) else type(batched_input)
datapoints.Image
if torchvision.prototype.transforms.utils.is_simple_tensor(batched_input)
else type(batched_input)
)
# This dictionary contains the number of rightmost dimensions that contain the actual data. # This dictionary contains the number of rightmost dimensions that contain the actual data.
# Everything to the left is considered a batch dimension. # Everything to the left is considered a batch dimension.
data_dims = { data_dims = {
......
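`is_simple_tensor` is how the v2 machinery tells plain tensors apart from datapoint subclasses, so that a bare tensor can be treated as an image. A rough equivalent of what the helper checks (simplified; the actual implementation lives in torchvision.transforms.v2.utils):

import torch
from torchvision import datapoints
from torchvision.datapoints._datapoint import Datapoint

def is_simple_tensor_sketch(inpt):
    # A "simple" tensor is a plain torch.Tensor that is not an Image, BoundingBox, Mask, or Video.
    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, Datapoint)

assert is_simple_tensor_sketch(torch.rand(3, 8, 8))
assert not is_simple_tensor_sketch(datapoints.Image(torch.rand(3, 8, 8)))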
...@@ -3,12 +3,12 @@ import pytest ...@@ -3,12 +3,12 @@ import pytest
import torch import torch
import torchvision.prototype.transforms.utils import torchvision.transforms.v2.utils
from prototype_common_utils import make_bounding_box, make_detection_mask, make_image from prototype_common_utils import make_bounding_box, make_detection_mask, make_image
from torchvision.prototype import datapoints from torchvision import datapoints
from torchvision.prototype.transforms.functional import to_image_pil from torchvision.transforms.v2.functional import to_image_pil
from torchvision.prototype.transforms.utils import has_all, has_any from torchvision.transforms.v2.utils import has_all, has_any
IMAGE = make_image(color_space="RGB") IMAGE = make_image(color_space="RGB")
...@@ -37,15 +37,15 @@ MASK = make_detection_mask(size=IMAGE.spatial_size) ...@@ -37,15 +37,15 @@ MASK = make_detection_mask(size=IMAGE.spatial_size)
((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True), ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.prototype.transforms.utils.is_simple_tensor), True), ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor), True),
( (
(torch.Tensor(IMAGE),), (torch.Tensor(IMAGE),),
(datapoints.Image, PIL.Image.Image, torchvision.prototype.transforms.utils.is_simple_tensor), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor),
True, True,
), ),
( (
(to_image_pil(IMAGE),), (to_image_pil(IMAGE),),
(datapoints.Image, PIL.Image.Image, torchvision.prototype.transforms.utils.is_simple_tensor), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor),
True, True,
), ),
], ],
......
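For orientation, the two helpers exercised above operate on a flat sequence of sample leaves: `has_any` is satisfied if any leaf matches any of the given types or predicates, while `has_all` requires every given type or predicate to be matched by some leaf. A minimal usage sketch built from the same fixtures used in this file:

from torchvision import datapoints
from torchvision.transforms.v2.utils import has_all, has_any
from prototype_common_utils import make_detection_mask, make_image

image = make_image(color_space="RGB")
mask = make_detection_mask(size=image.spatial_size)

assert has_any((image, mask), datapoints.Image)   # an Image leaf exists
assert not has_any((image,), datapoints.Mask)     # no Mask leaf in the sample
assert has_all((image, mask), datapoints.Image, datapoints.Mask)
assert not has_all((image,), datapoints.Image, datapoints.Mask)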
from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT
from ._image import Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
from ._mask import Mask
from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT
from ._dataset_wrapper import wrap_dataset_for_transforms_v2 # type: ignore[attr-defined] # usort: skip
...@@ -105,7 +105,7 @@ class Datapoint(torch.Tensor): ...@@ -105,7 +105,7 @@ class Datapoint(torch.Tensor):
# the class. This approach avoids the DataLoader issue described at # the class. This approach avoids the DataLoader issue described at
# https://github.com/pytorch/vision/pull/6476#discussion_r953588621 # https://github.com/pytorch/vision/pull/6476#discussion_r953588621
if Datapoint.__F is None: if Datapoint.__F is None:
from ..transforms import functional from ..transforms.v2 import functional
Datapoint.__F = functional Datapoint.__F = functional
return Datapoint.__F return Datapoint.__F
......
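The hunk above keeps the functional module out of the import path of `torchvision.datapoints`: the cached `Datapoint.__F` attribute is only resolved on first use, which sidesteps the circular import and the DataLoader issue referenced in the comment. A stripped-down sketch of the same lazy-caching pattern (names are illustrative, not the full Datapoint class):

class _LazyFunctionalAccess:
    __F = None  # cached module reference, filled in on first access

    @property
    def _F(self):
        if _LazyFunctionalAccess.__F is None:
            # Deferred import: only triggered when a kernel is actually needed.
            from torchvision.transforms.v2 import functional
            _LazyFunctionalAccess.__F = functional
        return _LazyFunctionalAccess.__F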
...@@ -8,9 +8,8 @@ from collections import defaultdict ...@@ -8,9 +8,8 @@ from collections import defaultdict
import torch import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
from torchvision import datasets from torchvision import datapoints, datasets
from torchvision.prototype import datapoints from torchvision.transforms.v2 import functional as F
from torchvision.prototype.transforms import functional as F
__all__ = ["wrap_dataset_for_transforms_v2"] __all__ = ["wrap_dataset_for_transforms_v2"]
......
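For context, the wrapper above is applied to an already constructed dataset instance and returns a dataset whose samples are expressed as datapoints the v2 transforms can dispatch on. A minimal usage sketch; the CocoDetection paths are placeholders:

from torchvision import datasets
from torchvision.datapoints import wrap_dataset_for_transforms_v2

dataset = datasets.CocoDetection("path/to/images", "path/to/annotations.json")
dataset = wrap_dataset_for_transforms_v2(dataset)
image, target = dataset[0]  # target now carries BoundingBox / Mask datapoints where applicable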
...@@ -24,7 +24,7 @@ class Image(Datapoint): ...@@ -24,7 +24,7 @@ class Image(Datapoint):
requires_grad: Optional[bool] = None, requires_grad: Optional[bool] = None,
) -> Image: ) -> Image:
if isinstance(data, PIL.Image.Image): if isinstance(data, PIL.Image.Image):
from torchvision.prototype.transforms import functional as F from torchvision.transforms.v2 import functional as F
data = F.pil_to_tensor(data) data = F.pil_to_tensor(data)
......
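As the branch above shows, building an `Image` datapoint from a PIL image routes through `pil_to_tensor`; a quick usage sketch:

import PIL.Image
from torchvision import datapoints

pil_image = PIL.Image.new("RGB", (32, 24))  # width 32, height 24
image = datapoints.Image(pil_image)         # converted via F.pil_to_tensor internally
print(image.shape)                          # torch.Size([3, 24, 32])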
...@@ -23,7 +23,7 @@ class Mask(Datapoint): ...@@ -23,7 +23,7 @@ class Mask(Datapoint):
requires_grad: Optional[bool] = None, requires_grad: Optional[bool] = None,
) -> Mask: ) -> Mask:
if isinstance(data, PIL.Image.Image): if isinstance(data, PIL.Image.Image):
from torchvision.prototype.transforms import functional as F from torchvision.transforms.v2 import functional as F
data = F.pil_to_tensor(data) data = F.pil_to_tensor(data)
......
from ._bounding_box import BoundingBox, BoundingBoxFormat
from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT
from ._image import Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
from ._label import Label, OneHotLabel from ._label import Label, OneHotLabel
from ._mask import Mask
from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT
from ._dataset_wrapper import wrap_dataset_for_transforms_v2 # type: ignore[attr-defined] # usort: skip
...@@ -5,7 +5,7 @@ from typing import Any, Optional, Sequence, Type, TypeVar, Union ...@@ -5,7 +5,7 @@ from typing import Any, Optional, Sequence, Type, TypeVar, Union
import torch import torch
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from ._datapoint import Datapoint from torchvision.datapoints._datapoint import Datapoint
L = TypeVar("L", bound="_LabelBase") L = TypeVar("L", bound="_LabelBase")
......
...@@ -6,7 +6,8 @@ import numpy as np ...@@ -6,7 +6,8 @@ import numpy as np
import torch import torch
from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper from torchdata.datapipes.iter import Filter, IterDataPipe, IterKeyZipper, Mapper
from torchvision.prototype.datapoints import BoundingBox, Label from torchvision.datapoints import BoundingBox
from torchvision.prototype.datapoints import Label
from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils import Dataset, EncodedImage, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import ( from torchvision.prototype.datasets.utils._internal import (
hint_sharding, hint_sharding,
......