Unverified Commit db310636 authored by Philip Meier, committed by GitHub

move passthrough for unknown types from dispatchers to transforms (#7804)

parent 87681314
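In short, this commit moves the handling of unknown input types out of the functional dispatchers and into the Transform base class: transforms no longer call F.<op>(inpt, ...) directly but go through self._call_kernel(F.<op>, inpt, ...), which resolves the kernel with _get_kernel(dispatcher, type(inpt), allow_passthrough=True) and therefore hands inputs of unregistered types back unchanged. The snippet below is a minimal toy model of that mechanism for orientation only; the registry layout and the names KERNEL_REGISTRY, register_kernel and get_kernel are illustrative stand-ins, not the actual torchvision internals.

# Toy model of the dispatcher/transform split introduced by this commit.
# KERNEL_REGISTRY, register_kernel and get_kernel are illustrative stand-ins.
from typing import Any, Callable, Dict

KERNEL_REGISTRY: Dict[Callable, Dict[type, Callable]] = {}

def register_kernel(dispatcher: Callable, input_type: type) -> Callable:
    def decorator(kernel: Callable) -> Callable:
        KERNEL_REGISTRY.setdefault(dispatcher, {})[input_type] = kernel
        return kernel
    return decorator

def get_kernel(dispatcher: Callable, input_type: type, allow_passthrough: bool = False) -> Callable:
    registry = KERNEL_REGISTRY.get(dispatcher, {})
    for cls in input_type.__mro__:  # also match kernels registered for a superclass
        if cls in registry:
            return registry[cls]
    if allow_passthrough:
        # unknown type: hand the input back unchanged instead of raising
        return lambda inpt, *args, **kwargs: inpt
    raise TypeError(f"{dispatcher.__name__}() supports inputs of type {sorted(t.__name__ for t in registry)}")

class Transform:
    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
        # Only transforms opt into passthrough; calling the dispatcher directly still raises.
        kernel = get_kernel(dispatcher, type(inpt), allow_passthrough=True)
        return kernel(inpt, *args, **kwargs)

The net effect, mirrored by the test changes below: a dispatcher called directly on an unregistered type now fails loudly with a TypeError, while the same input inside a transform pipeline is passed through untouched (with an explicit warning in a few transforms such as RandomErasing, FiveCrop and TenCrop).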
-import itertools
import re
import PIL.Image
@@ -19,7 +17,6 @@ from prototype_common_utils import make_label
from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
from torchvision.prototype import datapoints, transforms
-from torchvision.transforms.v2._utils import _convert_fill_arg
from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_image_pil
from torchvision.transforms.v2.utils import check_type, is_simple_tensor
@@ -187,66 +184,6 @@ class TestFixedSizeCrop:
        assert params["needs_pad"]
        assert any(pad > 0 for pad in params["padding"])
@pytest.mark.parametrize("needs", list(itertools.product((False, True), repeat=2)))
def test__transform(self, mocker, needs):
fill_sentinel = 12
padding_mode_sentinel = mocker.MagicMock()
transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel)
transform._transformed_types = (mocker.MagicMock,)
mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
needs_crop, needs_pad = needs
top_sentinel = mocker.MagicMock()
left_sentinel = mocker.MagicMock()
height_sentinel = mocker.MagicMock()
width_sentinel = mocker.MagicMock()
is_valid = mocker.MagicMock() if needs_crop else None
padding_sentinel = mocker.MagicMock()
mocker.patch(
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
return_value=dict(
needs_crop=needs_crop,
top=top_sentinel,
left=left_sentinel,
height=height_sentinel,
width=width_sentinel,
is_valid=is_valid,
padding=padding_sentinel,
needs_pad=needs_pad,
),
)
inpt_sentinel = mocker.MagicMock()
mock_crop = mocker.patch("torchvision.prototype.transforms._geometry.F.crop")
mock_pad = mocker.patch("torchvision.prototype.transforms._geometry.F.pad")
transform(inpt_sentinel)
if needs_crop:
mock_crop.assert_called_once_with(
inpt_sentinel,
top=top_sentinel,
left=left_sentinel,
height=height_sentinel,
width=width_sentinel,
)
else:
mock_crop.assert_not_called()
if needs_pad:
# If we cropped before, the input to F.pad is no longer inpt_sentinel. Thus, we can't use
# `MagicMock.assert_called_once_with` and have to perform the checks manually
mock_pad.assert_called_once()
args, kwargs = mock_pad.call_args
if not needs_crop:
assert args[0] is inpt_sentinel
assert args[1] is padding_sentinel
fill_sentinel = _convert_fill_arg(fill_sentinel)
assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel)
else:
mock_pad.assert_not_called()
def test__transform_culling(self, mocker): def test__transform_culling(self, mocker):
batch_size = 10 batch_size = 10
canvas_size = (10, 10) canvas_size = (10, 10)
......
@@ -27,7 +27,7 @@ from common_utils import (
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints
from torchvision.ops.boxes import box_iou
-from torchvision.transforms.functional import InterpolationMode, to_pil_image
+from torchvision.transforms.functional import to_pil_image
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.utils import check_type, is_simple_tensor, query_chw
@@ -419,46 +419,6 @@ class TestPad:
        with pytest.raises(ValueError, match="Padding mode should be either"):
            transforms.Pad(12, padding_mode="abc")
@pytest.mark.parametrize("padding", [1, (1, 2), [1, 2, 3, 4]])
@pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
@pytest.mark.parametrize("padding_mode", ["constant", "edge"])
def test__transform(self, padding, fill, padding_mode, mocker):
transform = transforms.Pad(padding, fill=fill, padding_mode=padding_mode)
fn = mocker.patch("torchvision.transforms.v2.functional.pad")
inpt = mocker.MagicMock(spec=datapoints.Image)
_ = transform(inpt)
fill = transforms._utils._convert_fill_arg(fill)
if isinstance(padding, tuple):
padding = list(padding)
fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
@pytest.mark.parametrize("fill", [12, {datapoints.Image: 12, datapoints.Mask: 34}])
def test__transform_image_mask(self, fill, mocker):
transform = transforms.Pad(1, fill=fill, padding_mode="constant")
fn = mocker.patch("torchvision.transforms.v2.functional.pad")
image = datapoints.Image(torch.rand(3, 32, 32))
mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32)))
inpt = [image, mask]
_ = transform(inpt)
if isinstance(fill, int):
fill = transforms._utils._convert_fill_arg(fill)
calls = [
mocker.call(image, padding=1, fill=fill, padding_mode="constant"),
mocker.call(mask, padding=1, fill=fill, padding_mode="constant"),
]
else:
fill_img = transforms._utils._convert_fill_arg(fill[type(image)])
fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)])
calls = [
mocker.call(image, padding=1, fill=fill_img, padding_mode="constant"),
mocker.call(mask, padding=1, fill=fill_mask, padding_mode="constant"),
]
fn.assert_has_calls(calls)
class TestRandomZoomOut: class TestRandomZoomOut:
def test_assertions(self): def test_assertions(self):
...@@ -487,56 +447,6 @@ class TestRandomZoomOut: ...@@ -487,56 +447,6 @@ class TestRandomZoomOut:
assert 0 <= params["padding"][2] <= (side_range[1] - 1) * w assert 0 <= params["padding"][2] <= (side_range[1] - 1) * w
assert 0 <= params["padding"][3] <= (side_range[1] - 1) * h assert 0 <= params["padding"][3] <= (side_range[1] - 1) * h
@pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
@pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]])
def test__transform(self, fill, side_range, mocker):
inpt = make_image((24, 32))
transform = transforms.RandomZoomOut(fill=fill, side_range=side_range, p=1)
fn = mocker.patch("torchvision.transforms.v2.functional.pad")
# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
# Otherwise, we can mock transform._get_params
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
torch.rand(1) # random apply changes random state
params = transform._get_params([inpt])
fill = transforms._utils._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill)
@pytest.mark.parametrize("fill", [12, {datapoints.Image: 12, datapoints.Mask: 34}])
def test__transform_image_mask(self, fill, mocker):
transform = transforms.RandomZoomOut(fill=fill, p=1.0)
fn = mocker.patch("torchvision.transforms.v2.functional.pad")
image = datapoints.Image(torch.rand(3, 32, 32))
mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32)))
inpt = [image, mask]
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
torch.rand(1) # random apply changes random state
params = transform._get_params(inpt)
if isinstance(fill, int):
fill = transforms._utils._convert_fill_arg(fill)
calls = [
mocker.call(image, **params, fill=fill),
mocker.call(mask, **params, fill=fill),
]
else:
fill_img = transforms._utils._convert_fill_arg(fill[type(image)])
fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)])
calls = [
mocker.call(image, **params, fill=fill_img),
mocker.call(mask, **params, fill=fill_mask),
]
fn.assert_has_calls(calls)
class TestRandomCrop: class TestRandomCrop:
def test_assertions(self): def test_assertions(self):
...@@ -599,51 +509,6 @@ class TestRandomCrop: ...@@ -599,51 +509,6 @@ class TestRandomCrop:
assert params["needs_pad"] is any(padding) assert params["needs_pad"] is any(padding)
assert params["padding"] == padding assert params["padding"] == padding
@pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]])
@pytest.mark.parametrize("pad_if_needed", [False, True])
@pytest.mark.parametrize("fill", [False, True])
@pytest.mark.parametrize("padding_mode", ["constant", "edge"])
def test__transform(self, padding, pad_if_needed, fill, padding_mode, mocker):
output_size = [10, 12]
transform = transforms.RandomCrop(
output_size, padding=padding, pad_if_needed=pad_if_needed, fill=fill, padding_mode=padding_mode
)
h, w = size = (32, 32)
inpt = make_image(size)
if isinstance(padding, int):
new_size = (h + padding, w + padding)
elif isinstance(padding, list):
new_size = (h + sum(padding[0::2]), w + sum(padding[1::2]))
else:
new_size = size
expected = make_image(new_size)
_ = mocker.patch("torchvision.transforms.v2.functional.pad", return_value=expected)
fn_crop = mocker.patch("torchvision.transforms.v2.functional.crop")
# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
# Otherwise, we can mock transform._get_params
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
params = transform._get_params([inpt])
if padding is None and not pad_if_needed:
fn_crop.assert_called_once_with(
inpt, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1]
)
elif not pad_if_needed:
fn_crop.assert_called_once_with(
expected, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1]
)
elif padding is None:
# vfdev-5: I do not know how to mock and test this case
pass
else:
# vfdev-5: I do not know how to mock and test this case
pass
class TestGaussianBlur: class TestGaussianBlur:
def test_assertions(self): def test_assertions(self):
...@@ -675,62 +540,6 @@ class TestGaussianBlur: ...@@ -675,62 +540,6 @@ class TestGaussianBlur:
assert sigma[0] <= params["sigma"][0] <= sigma[1] assert sigma[0] <= params["sigma"][0] <= sigma[1]
assert sigma[0] <= params["sigma"][1] <= sigma[1] assert sigma[0] <= params["sigma"][1] <= sigma[1]
@pytest.mark.parametrize("kernel_size", [3, [3, 5], (5, 3)])
@pytest.mark.parametrize("sigma", [2.0, [2.0, 3.0]])
def test__transform(self, kernel_size, sigma, mocker):
transform = transforms.GaussianBlur(kernel_size=kernel_size, sigma=sigma)
if isinstance(kernel_size, (tuple, list)):
assert transform.kernel_size == kernel_size
else:
kernel_size = (kernel_size, kernel_size)
assert transform.kernel_size == kernel_size
if isinstance(sigma, (tuple, list)):
assert transform.sigma == sigma
else:
assert transform.sigma == [sigma, sigma]
fn = mocker.patch("torchvision.transforms.v2.functional.gaussian_blur")
inpt = mocker.MagicMock(spec=datapoints.Image)
inpt.num_channels = 3
inpt.canvas_size = (24, 32)
# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
# Otherwise, we can mock transform._get_params
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
params = transform._get_params([inpt])
fn.assert_called_once_with(inpt, kernel_size, **params)
class TestRandomColorOp:
@pytest.mark.parametrize("p", [0.0, 1.0])
@pytest.mark.parametrize(
"transform_cls, func_op_name, kwargs",
[
(transforms.RandomEqualize, "equalize", {}),
(transforms.RandomInvert, "invert", {}),
(transforms.RandomAutocontrast, "autocontrast", {}),
(transforms.RandomPosterize, "posterize", {"bits": 4}),
(transforms.RandomSolarize, "solarize", {"threshold": 0.5}),
(transforms.RandomAdjustSharpness, "adjust_sharpness", {"sharpness_factor": 0.5}),
],
)
def test__transform(self, p, transform_cls, func_op_name, kwargs, mocker):
transform = transform_cls(p=p, **kwargs)
fn = mocker.patch(f"torchvision.transforms.v2.functional.{func_op_name}")
inpt = mocker.MagicMock(spec=datapoints.Image)
_ = transform(inpt)
if p > 0.0:
fn.assert_called_once_with(inpt, **kwargs)
else:
assert fn.call_count == 0
class TestRandomPerspective: class TestRandomPerspective:
def test_assertions(self): def test_assertions(self):
...@@ -751,28 +560,6 @@ class TestRandomPerspective: ...@@ -751,28 +560,6 @@ class TestRandomPerspective:
assert "coefficients" in params assert "coefficients" in params
assert len(params["coefficients"]) == 8 assert len(params["coefficients"]) == 8
@pytest.mark.parametrize("distortion_scale", [0.1, 0.7])
def test__transform(self, distortion_scale, mocker):
interpolation = InterpolationMode.BILINEAR
fill = 12
transform = transforms.RandomPerspective(distortion_scale, fill=fill, interpolation=interpolation)
fn = mocker.patch("torchvision.transforms.v2.functional.perspective")
inpt = make_image((24, 32))
# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
# Otherwise, we can mock transform._get_params
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
torch.rand(1) # random apply changes random state
params = transform._get_params([inpt])
fill = transforms._utils._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, None, None, **params, fill=fill, interpolation=interpolation)
class TestElasticTransform: class TestElasticTransform:
def test_assertions(self): def test_assertions(self):
...@@ -813,35 +600,6 @@ class TestElasticTransform: ...@@ -813,35 +600,6 @@ class TestElasticTransform:
assert (-alpha / w <= displacement[0, ..., 0]).all() and (displacement[0, ..., 0] <= alpha / w).all() assert (-alpha / w <= displacement[0, ..., 0]).all() and (displacement[0, ..., 0] <= alpha / w).all()
assert (-alpha / h <= displacement[0, ..., 1]).all() and (displacement[0, ..., 1] <= alpha / h).all() assert (-alpha / h <= displacement[0, ..., 1]).all() and (displacement[0, ..., 1] <= alpha / h).all()
@pytest.mark.parametrize("alpha", [5.0, [5.0, 10.0]])
@pytest.mark.parametrize("sigma", [2.0, [2.0, 5.0]])
def test__transform(self, alpha, sigma, mocker):
interpolation = InterpolationMode.BILINEAR
fill = 12
transform = transforms.ElasticTransform(alpha, sigma=sigma, fill=fill, interpolation=interpolation)
if isinstance(alpha, float):
assert transform.alpha == [alpha, alpha]
else:
assert transform.alpha == alpha
if isinstance(sigma, float):
assert transform.sigma == [sigma, sigma]
else:
assert transform.sigma == sigma
fn = mocker.patch("torchvision.transforms.v2.functional.elastic")
inpt = mocker.MagicMock(spec=datapoints.Image)
inpt.num_channels = 3
inpt.canvas_size = (24, 32)
# Let's mock transform._get_params to control the output:
transform._get_params = mocker.MagicMock()
_ = transform(inpt)
params = transform._get_params([inpt])
fill = transforms._utils._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
class TestRandomErasing: class TestRandomErasing:
def test_assertions(self): def test_assertions(self):
...@@ -889,40 +647,6 @@ class TestRandomErasing: ...@@ -889,40 +647,6 @@ class TestRandomErasing:
assert 0 <= i <= height - h assert 0 <= i <= height - h
assert 0 <= j <= width - w assert 0 <= j <= width - w
@pytest.mark.parametrize("p", [0, 1])
def test__transform(self, mocker, p):
transform = transforms.RandomErasing(p=p)
transform._transformed_types = (mocker.MagicMock,)
i_sentinel = mocker.MagicMock()
j_sentinel = mocker.MagicMock()
h_sentinel = mocker.MagicMock()
w_sentinel = mocker.MagicMock()
v_sentinel = mocker.MagicMock()
mocker.patch(
"torchvision.transforms.v2._augment.RandomErasing._get_params",
return_value=dict(i=i_sentinel, j=j_sentinel, h=h_sentinel, w=w_sentinel, v=v_sentinel),
)
inpt_sentinel = mocker.MagicMock()
mock = mocker.patch("torchvision.transforms.v2._augment.F.erase")
output = transform(inpt_sentinel)
if p:
mock.assert_called_once_with(
inpt_sentinel,
i=i_sentinel,
j=j_sentinel,
h=h_sentinel,
w=w_sentinel,
v=v_sentinel,
inplace=transform.inplace,
)
else:
mock.assert_not_called()
assert output is inpt_sentinel
class TestTransform: class TestTransform:
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -1111,23 +835,12 @@ class TestRandomIoUCrop: ...@@ -1111,23 +835,12 @@ class TestRandomIoUCrop:
sample = [image, bboxes, masks] sample = [image, bboxes, masks]
fn = mocker.patch("torchvision.transforms.v2.functional.crop", side_effect=lambda x, **params: x)
is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool) is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool)
params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area) params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area)
transform._get_params = mocker.MagicMock(return_value=params) transform._get_params = mocker.MagicMock(return_value=params)
output = transform(sample) output = transform(sample)
assert fn.call_count == 3
expected_calls = [
mocker.call(image, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
mocker.call(bboxes, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
mocker.call(masks, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
]
fn.assert_has_calls(expected_calls)
# check number of bboxes vs number of labels: # check number of bboxes vs number of labels:
output_bboxes = output[1] output_bboxes = output[1]
assert isinstance(output_bboxes, datapoints.BoundingBoxes) assert isinstance(output_bboxes, datapoints.BoundingBoxes)
...@@ -1164,29 +877,6 @@ class TestScaleJitter: ...@@ -1164,29 +877,6 @@ class TestScaleJitter:
assert int(canvas_size[0] * r_min) <= height <= int(canvas_size[0] * r_max) assert int(canvas_size[0] * r_min) <= height <= int(canvas_size[0] * r_max)
assert int(canvas_size[1] * r_min) <= width <= int(canvas_size[1] * r_max) assert int(canvas_size[1] * r_min) <= width <= int(canvas_size[1] * r_max)
-    def test__transform(self, mocker):
-        interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
-        antialias_sentinel = mocker.MagicMock()
-        transform = transforms.ScaleJitter(
-            target_size=(16, 12), interpolation=interpolation_sentinel, antialias=antialias_sentinel
-        )
-        transform._transformed_types = (mocker.MagicMock,)
-        size_sentinel = mocker.MagicMock()
-        mocker.patch(
-            "torchvision.transforms.v2._geometry.ScaleJitter._get_params", return_value=dict(size=size_sentinel)
-        )
-        inpt_sentinel = mocker.MagicMock()
-        mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
-        transform(inpt_sentinel)
-        mock.assert_called_once_with(
-            inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel
-        )
class TestRandomShortestSize:
    @pytest.mark.parametrize("min_size,max_size", [([5, 9], 20), ([5, 9], None)])
@@ -1211,30 +901,6 @@ class TestRandomShortestSize:
        else:
            assert shorter in min_size
-    def test__transform(self, mocker):
-        interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
-        antialias_sentinel = mocker.MagicMock()
-        transform = transforms.RandomShortestSize(
-            min_size=[3, 5, 7], max_size=12, interpolation=interpolation_sentinel, antialias=antialias_sentinel
-        )
-        transform._transformed_types = (mocker.MagicMock,)
-        size_sentinel = mocker.MagicMock()
-        mocker.patch(
-            "torchvision.transforms.v2._geometry.RandomShortestSize._get_params",
-            return_value=dict(size=size_sentinel),
-        )
-        inpt_sentinel = mocker.MagicMock()
-        mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
-        transform(inpt_sentinel)
-        mock.assert_called_once_with(
-            inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel
-        )
class TestLinearTransformation:
    def test_assertions(self):
@@ -1260,7 +926,7 @@ class TestLinearTransformation:
        transform = transforms.LinearTransformation(m, v)
        if isinstance(inpt, PIL.Image.Image):
-            with pytest.raises(TypeError, match="LinearTransformation does not work on PIL Images"):
+            with pytest.raises(TypeError, match="does not support PIL images"):
                transform(inpt)
        else:
            output = transform(inpt)
@@ -1284,30 +950,6 @@ class TestRandomResize:
            assert min_size <= size < max_size
-    def test__transform(self, mocker):
-        interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
-        antialias_sentinel = mocker.MagicMock()
-        transform = transforms.RandomResize(
-            min_size=-1, max_size=-1, interpolation=interpolation_sentinel, antialias=antialias_sentinel
-        )
-        transform._transformed_types = (mocker.MagicMock,)
-        size_sentinel = mocker.MagicMock()
-        mocker.patch(
-            "torchvision.transforms.v2._geometry.RandomResize._get_params",
-            return_value=dict(size=size_sentinel),
-        )
-        inpt_sentinel = mocker.MagicMock()
-        mock_resize = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
-        transform(inpt_sentinel)
-        mock_resize.assert_called_with(
-            inpt_sentinel, size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel
-        )
class TestUniformTemporalSubsample:
    @pytest.mark.parametrize(
...
@@ -1259,68 +1259,6 @@ class TestRefSegTransforms:
    def test_common(self, t_ref, t, data_kwargs):
        self.check(t, t_ref, data_kwargs)
-    def check_resize(self, mocker, t_ref, t):
-        mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
-        mock_ref = mocker.patch("torchvision.transforms.functional.resize")
-        for dp, dp_ref in self.make_datapoints():
-            mock.reset_mock()
-            mock_ref.reset_mock()
-            self.set_seed()
-            t(dp)
-            assert mock.call_count == 2
-            assert all(
-                actual is expected
-                for actual, expected in zip([call_args[0][0] for call_args in mock.call_args_list], dp)
-            )
-            self.set_seed()
-            t_ref(*dp_ref)
-            assert mock_ref.call_count == 2
-            assert all(
-                actual is expected
-                for actual, expected in zip([call_args[0][0] for call_args in mock_ref.call_args_list], dp_ref)
-            )
-            for args_kwargs, args_kwargs_ref in zip(mock.call_args_list, mock_ref.call_args_list):
-                assert args_kwargs[0][1] == [args_kwargs_ref[0][1]]
-    def test_random_resize_train(self, mocker):
-        base_size = 520
-        min_size = base_size // 2
-        max_size = base_size * 2
-        randint = torch.randint
-        def patched_randint(a, b, *other_args, **kwargs):
-            if kwargs or len(other_args) > 1 or other_args[0] != ():
-                return randint(a, b, *other_args, **kwargs)
-            return random.randint(a, b)
-        # We are patching torch.randint -> random.randint here, because we can't patch the modules that are not imported
-        # normally
-        t = v2_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True)
-        mocker.patch(
-            "torchvision.transforms.v2._geometry.torch.randint",
-            new=patched_randint,
-        )
-        t_ref = seg_transforms.RandomResize(min_size=min_size, max_size=max_size)
-        self.check_resize(mocker, t_ref, t)
-    def test_random_resize_eval(self, mocker):
-        torch.manual_seed(0)
-        base_size = 520
-        t = v2_transforms.Resize(size=base_size, antialias=True)
-        t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size)
-        self.check_resize(mocker, t_ref, t)
    @pytest.mark.parametrize(
        ("legacy_dispatcher", "name_only_params"),
...
@@ -39,7 +39,7 @@ from torchvision import datapoints
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.functional import pil_modes_mapping
from torchvision.transforms.v2 import functional as F
-from torchvision.transforms.v2.functional._utils import _get_kernel, _KERNEL_REGISTRY, _noop, _register_kernel_internal
+from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal
@pytest.fixture(autouse=True)
@@ -376,35 +376,6 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_size
    return torch.stack([transform(b) for b in bounding_boxes.reshape(-1, 4).unbind()]).reshape(bounding_boxes.shape)
-@pytest.mark.parametrize(
-    ("dispatcher", "registered_input_types"),
-    [(dispatcher, set(registry.keys())) for dispatcher, registry in _KERNEL_REGISTRY.items()],
-)
-def test_exhaustive_kernel_registration(dispatcher, registered_input_types):
-    missing = {
-        torch.Tensor,
-        PIL.Image.Image,
-        datapoints.Image,
-        datapoints.BoundingBoxes,
-        datapoints.Mask,
-        datapoints.Video,
-    } - registered_input_types
-    if missing:
-        names = sorted(str(t) for t in missing)
-        raise AssertionError(
-            "\n".join(
-                [
-                    f"The dispatcher '{dispatcher.__name__}' has no kernel registered for",
-                    "",
-                    *[f"- {name}" for name in names],
-                    "",
-                    f"If available, register the kernels with @_register_kernel_internal({dispatcher.__name__}, ...).",
-                    f"If not, register explicit no-ops with @_register_explicit_noop({', '.join(names)})",
-                ]
-            )
-        )
class TestResize:
    INPUT_SIZE = (17, 11)
    OUTPUT_SIZES = [17, [17], (17,), [12, 13], (12, 13)]
@@ -2128,9 +2099,20 @@ class TestRegisterKernel:
        with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"):
            F.register_kernel(F.resize, object)
-        with pytest.raises(ValueError, match="already has a kernel registered for type"):
+        with pytest.raises(ValueError, match="cannot be registered for the builtin datapoint classes"):
            F.register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor)
+        class CustomDatapoint(datapoints.Datapoint):
+            pass
+        def resize_custom_datapoint():
+            pass
+        F.register_kernel(F.resize, CustomDatapoint)(resize_custom_datapoint)
+        with pytest.raises(ValueError, match="already has a kernel registered for type"):
+            F.register_kernel(F.resize, CustomDatapoint)(resize_custom_datapoint)
class TestGetKernel:
    # We are using F.resize as dispatcher and the kernels below as proxy. Any other dispatcher / kernels combination
@@ -2152,13 +2134,7 @@
            pass
        for input_type in [str, int, object, MyTensor, MyPILImage]:
-            with pytest.raises(
-                TypeError,
-                match=(
-                    "supports inputs of type torch.Tensor, PIL.Image.Image, "
-                    "and subclasses of torchvision.datapoints.Datapoint"
-                ),
-            ):
+            with pytest.raises(TypeError, match="supports inputs of type"):
                _get_kernel(F.resize, input_type)
    def test_exact_match(self):
@@ -2211,8 +2187,8 @@
        class MyDatapoint(datapoints.Datapoint):
            pass
-        # Note that this will be an error in the future
-        assert _get_kernel(F.resize, MyDatapoint) is _noop
+        with pytest.raises(TypeError, match="supports inputs of type"):
+            _get_kernel(F.resize, MyDatapoint)
        def resize_my_datapoint():
            pass
...
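The reworked TestRegisterKernel and TestGetKernel expectations above encode the new registration rules: the builtin datapoint classes ship with fixed kernels and cannot be re-registered, a custom Datapoint subclass can register a kernel exactly once, and looking up a kernel for a subclass without one now raises instead of silently no-op'ing. A hedged sketch of how this looks from user code, assuming the 0.16-era beta API (torchvision.datapoints and transforms.v2); module paths changed in later releases.

# Assumes the beta datapoints API of this era; for illustration only.
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

class MyDatapoint(datapoints.Datapoint):
    pass

@F.register_kernel(F.resize, MyDatapoint)
def resize_my_datapoint(inpt, size, **kwargs):
    return inpt  # trivial kernel, for illustration only

# F.register_kernel(F.resize, datapoints.Image)(...)  # ValueError: builtin datapoint class
# F.register_kernel(F.resize, MyDatapoint)(...)       # ValueError: already has a kernel registered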
@@ -101,7 +101,8 @@ class FixedSizeCrop(Transform):
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if params["needs_crop"]:
-            inpt = F.crop(
+            inpt = self._call_kernel(
+                F.crop,
                inpt,
                top=params["top"],
                left=params["left"],
@@ -120,6 +121,6 @@ class FixedSizeCrop(Transform):
        if params["needs_pad"]:
            fill = _get_fill(self._fill, type(inpt))
-            inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
+            inpt = self._call_kernel(F.pad, inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
        return inpt
import math
import numbers
import warnings
-from typing import Any, Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Tuple
import PIL.Image
import torch
@@ -91,6 +91,14 @@ class RandomErasing(_RandomApplyTransform):
        self._log_ratio = torch.log(torch.tensor(self.ratio))
+    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+        if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
+            warnings.warn(
+                f"{type(self).__name__}() is currently passing through inputs of type "
+                f"datapoints.{type(inpt).__name__}. This will likely change in the future."
+            )
+        return super()._call_kernel(dispatcher, inpt, *args, **kwargs)
    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        img_c, img_h, img_w = query_chw(flat_inputs)
@@ -131,7 +139,7 @@ class RandomErasing(_RandomApplyTransform):
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if params["v"] is not None:
-            inpt = F.erase(inpt, **params, inplace=self.inplace)
+            inpt = self._call_kernel(F.erase, inpt, **params, inplace=self.inplace)
        return inpt
...
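The _call_kernel override added to RandomErasing above is the "loud" variant of the passthrough: bounding boxes and masks are handed back unchanged, but with a warning that this behavior may change. A hedged usage sketch under the same beta-API assumption as above; the warning text follows the diff.

# RandomErasing erases the image but warns and passes the mask through.
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import RandomErasing

image = datapoints.Image(torch.rand(3, 32, 32))
mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32)))

out_image, out_mask = RandomErasing(p=1.0)([image, mask])
# out_mask is the input mask, returned after a warning along the lines of
# "RandomErasing() is currently passing through inputs of type datapoints.Mask. ..."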
import collections.abc
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
-import PIL.Image
import torch
-from torchvision import datapoints, transforms as _transforms
+from torchvision import transforms as _transforms
from torchvision.transforms.v2 import functional as F, Transform
from ._transform import _RandomApplyTransform
-from .utils import is_simple_tensor, query_chw
+from .utils import query_chw
class Grayscale(Transform):
@@ -24,19 +23,12 @@ class Grayscale(Transform):
    _v1_transform_cls = _transforms.Grayscale
-    _transformed_types = (
-        datapoints.Image,
-        PIL.Image.Image,
-        is_simple_tensor,
-        datapoints.Video,
-    )
    def __init__(self, num_output_channels: int = 1):
        super().__init__()
        self.num_output_channels = num_output_channels
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
+        return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=self.num_output_channels)
class RandomGrayscale(_RandomApplyTransform):
@@ -55,13 +47,6 @@ class RandomGrayscale(_RandomApplyTransform):
    _v1_transform_cls = _transforms.RandomGrayscale
-    _transformed_types = (
-        datapoints.Image,
-        PIL.Image.Image,
-        is_simple_tensor,
-        datapoints.Video,
-    )
    def __init__(self, p: float = 0.1) -> None:
        super().__init__(p=p)
@@ -70,7 +55,7 @@ class RandomGrayscale(_RandomApplyTransform):
        return dict(num_input_channels=num_input_channels)
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
+        return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"])
class ColorJitter(Transform):
@@ -167,13 +152,13 @@ class ColorJitter(Transform):
        hue_factor = params["hue_factor"]
        for fn_id in params["fn_idx"]:
            if fn_id == 0 and brightness_factor is not None:
-                output = F.adjust_brightness(output, brightness_factor=brightness_factor)
+                output = self._call_kernel(F.adjust_brightness, output, brightness_factor=brightness_factor)
            elif fn_id == 1 and contrast_factor is not None:
-                output = F.adjust_contrast(output, contrast_factor=contrast_factor)
+                output = self._call_kernel(F.adjust_contrast, output, contrast_factor=contrast_factor)
            elif fn_id == 2 and saturation_factor is not None:
-                output = F.adjust_saturation(output, saturation_factor=saturation_factor)
+                output = self._call_kernel(F.adjust_saturation, output, saturation_factor=saturation_factor)
            elif fn_id == 3 and hue_factor is not None:
-                output = F.adjust_hue(output, hue_factor=hue_factor)
+                output = self._call_kernel(F.adjust_hue, output, hue_factor=hue_factor)
        return output
@@ -183,19 +168,12 @@ class RandomChannelPermutation(Transform):
    .. v2betastatus:: RandomChannelPermutation transform
    """
-    _transformed_types = (
-        datapoints.Image,
-        PIL.Image.Image,
-        is_simple_tensor,
-        datapoints.Video,
-    )
    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        num_channels, *_ = query_chw(flat_inputs)
        return dict(permutation=torch.randperm(num_channels))
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.permute_channels(inpt, params["permutation"])
+        return self._call_kernel(F.permute_channels, inpt, params["permutation"])
class RandomPhotometricDistort(Transform):
@@ -224,13 +202,6 @@ class RandomPhotometricDistort(Transform):
        Default is 0.5.
    """
-    _transformed_types = (
-        datapoints.Image,
-        PIL.Image.Image,
-        is_simple_tensor,
-        datapoints.Video,
-    )
    def __init__(
        self,
        brightness: Tuple[float, float] = (0.875, 1.125),
@@ -263,17 +234,17 @@ class RandomPhotometricDistort(Transform):
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if params["brightness_factor"] is not None:
-            inpt = F.adjust_brightness(inpt, brightness_factor=params["brightness_factor"])
+            inpt = self._call_kernel(F.adjust_brightness, inpt, brightness_factor=params["brightness_factor"])
        if params["contrast_factor"] is not None and params["contrast_before"]:
-            inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"])
+            inpt = self._call_kernel(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"])
        if params["saturation_factor"] is not None:
-            inpt = F.adjust_saturation(inpt, saturation_factor=params["saturation_factor"])
+            inpt = self._call_kernel(F.adjust_saturation, inpt, saturation_factor=params["saturation_factor"])
        if params["hue_factor"] is not None:
-            inpt = F.adjust_hue(inpt, hue_factor=params["hue_factor"])
+            inpt = self._call_kernel(F.adjust_hue, inpt, hue_factor=params["hue_factor"])
        if params["contrast_factor"] is not None and not params["contrast_before"]:
-            inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"])
+            inpt = self._call_kernel(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"])
        if params["channel_permutation"] is not None:
-            inpt = F.permute_channels(inpt, permutation=params["channel_permutation"])
+            inpt = self._call_kernel(F.permute_channels, inpt, permutation=params["channel_permutation"])
        return inpt
@@ -293,7 +264,7 @@ class RandomEqualize(_RandomApplyTransform):
    _v1_transform_cls = _transforms.RandomEqualize
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.equalize(inpt)
+        return self._call_kernel(F.equalize, inpt)
class RandomInvert(_RandomApplyTransform):
@@ -312,7 +283,7 @@ class RandomInvert(_RandomApplyTransform):
    _v1_transform_cls = _transforms.RandomInvert
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.invert(inpt)
+        return self._call_kernel(F.invert, inpt)
class RandomPosterize(_RandomApplyTransform):
@@ -337,7 +308,7 @@ class RandomPosterize(_RandomApplyTransform):
        self.bits = bits
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.posterize(inpt, bits=self.bits)
+        return self._call_kernel(F.posterize, inpt, bits=self.bits)
class RandomSolarize(_RandomApplyTransform):
@@ -362,7 +333,7 @@ class RandomSolarize(_RandomApplyTransform):
        self.threshold = threshold
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.solarize(inpt, threshold=self.threshold)
+        return self._call_kernel(F.solarize, inpt, threshold=self.threshold)
class RandomAutocontrast(_RandomApplyTransform):
@@ -381,7 +352,7 @@ class RandomAutocontrast(_RandomApplyTransform):
    _v1_transform_cls = _transforms.RandomAutocontrast
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.autocontrast(inpt)
+        return self._call_kernel(F.autocontrast, inpt)
class RandomAdjustSharpness(_RandomApplyTransform):
@@ -406,4 +377,4 @@ class RandomAdjustSharpness(_RandomApplyTransform):
        self.sharpness_factor = sharpness_factor
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.adjust_sharpness(inpt, sharpness_factor=self.sharpness_factor)
+        return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=self.sharpness_factor)
import math
import numbers
import warnings
-from typing import Any, cast, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Callable, cast, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union
import PIL.Image
import torch
@@ -44,7 +44,7 @@ class RandomHorizontalFlip(_RandomApplyTransform):
    _v1_transform_cls = _transforms.RandomHorizontalFlip
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.horizontal_flip(inpt)
+        return self._call_kernel(F.horizontal_flip, inpt)
class RandomVerticalFlip(_RandomApplyTransform):
@@ -64,7 +64,7 @@ class RandomVerticalFlip(_RandomApplyTransform):
    _v1_transform_cls = _transforms.RandomVerticalFlip
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.vertical_flip(inpt)
+        return self._call_kernel(F.vertical_flip, inpt)
class Resize(Transform):
@@ -152,7 +152,8 @@ class Resize(Transform):
        self.antialias = antialias
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.resize(
+        return self._call_kernel(
+            F.resize,
            inpt,
            self.size,
            interpolation=self.interpolation,
@@ -186,7 +187,7 @@ class CenterCrop(Transform):
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.center_crop(inpt, output_size=self.size)
+        return self._call_kernel(F.center_crop, inpt, output_size=self.size)
class RandomResizedCrop(Transform):
@@ -307,8 +308,8 @@ class RandomResizedCrop(Transform):
        return dict(top=i, left=j, height=h, width=w)
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.resized_crop(
-            inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias
+        return self._call_kernel(
+            F.resized_crop, inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias
        )
@@ -357,8 +358,16 @@ class FiveCrop(Transform):
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
+    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+        if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
+            warnings.warn(
+                f"{type(self).__name__}() is currently passing through inputs of type "
+                f"datapoints.{type(inpt).__name__}. This will likely change in the future."
+            )
+        return super()._call_kernel(dispatcher, inpt, *args, **kwargs)
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.five_crop(inpt, self.size)
+        return self._call_kernel(F.five_crop, inpt, self.size)
    def _check_inputs(self, flat_inputs: List[Any]) -> None:
        if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):
@@ -396,12 +405,20 @@ class TenCrop(Transform):
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
        self.vertical_flip = vertical_flip
+    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+        if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
+            warnings.warn(
+                f"{type(self).__name__}() is currently passing through inputs of type "
+                f"datapoints.{type(inpt).__name__}. This will likely change in the future."
+            )
+        return super()._call_kernel(dispatcher, inpt, *args, **kwargs)
    def _check_inputs(self, flat_inputs: List[Any]) -> None:
        if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):
            raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
+        return self._call_kernel(F.ten_crop, inpt, self.size, vertical_flip=self.vertical_flip)
class Pad(Transform):
@@ -475,7 +492,7 @@ class Pad(Transform):
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        fill = _get_fill(self._fill, type(inpt))
-        return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]
+        return self._call_kernel(F.pad, inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]
class RandomZoomOut(_RandomApplyTransform):
@@ -545,7 +562,7 @@ class RandomZoomOut(_RandomApplyTransform):
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        fill = _get_fill(self._fill, type(inpt))
-        return F.pad(inpt, **params, fill=fill)
+        return self._call_kernel(F.pad, inpt, **params, fill=fill)
class RandomRotation(Transform):
@@ -611,7 +628,8 @@ class RandomRotation(Transform):
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        fill = _get_fill(self._fill, type(inpt))
-        return F.rotate(
+        return self._call_kernel(
+            F.rotate,
            inpt,
            **params,
            interpolation=self.interpolation,
@@ -733,7 +751,8 @@ class RandomAffine(Transform):
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        fill = _get_fill(self._fill, type(inpt))
-        return F.affine(
+        return self._call_kernel(
+            F.affine,
            inpt,
            **params,
            interpolation=self.interpolation,
@@ -889,10 +908,12 @@ class RandomCrop(Transform):
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if params["needs_pad"]:
            fill = _get_fill(self._fill, type(inpt))
-            inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
+            inpt = self._call_kernel(F.pad, inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
        if params["needs_crop"]:
-            inpt = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
+            inpt = self._call_kernel(
+                F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
+            )
        return inpt
@@ -973,7 +994,8 @@ class RandomPerspective(_RandomApplyTransform):
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        fill = _get_fill(self._fill, type(inpt))
-        return F.perspective(
+        return self._call_kernel(
+            F.perspective,
            inpt,
            None,
            None,
@@ -1050,7 +1072,7 @@ class ElasticTransform(Transform):
        # if kernel size is even we have to make it odd
        if kx % 2 == 0:
            kx += 1
-        dx = F.gaussian_blur(dx, [kx, kx], list(self.sigma))
+        dx = self._call_kernel(F.gaussian_blur, dx, [kx, kx], list(self.sigma))
        dx = dx * self.alpha[0] / size[0]
        dy = torch.rand([1, 1] + size) * 2 - 1
@@ -1059,14 +1081,15 @@ class ElasticTransform(Transform):
        # if kernel size is even we have to make it odd
        if ky % 2 == 0:
            ky += 1
-        dy = F.gaussian_blur(dy, [ky, ky], list(self.sigma))
+        dy = self._call_kernel(F.gaussian_blur, dy, [ky, ky], list(self.sigma))
        dy = dy * self.alpha[1] / size[1]
        displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1])  # 1 x H x W x 2
        return dict(displacement=displacement)
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        fill = _get_fill(self._fill, type(inpt))
-        return F.elastic(
+        return self._call_kernel(
+            F.elastic,
            inpt,
            **params,
            fill=fill,
@@ -1164,7 +1187,9 @@ class RandomIoUCrop(Transform):
                # check for any valid boxes with centers within the crop area
                xyxy_bboxes = F.convert_format_bounding_boxes(
-                    bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY
+                    bboxes.as_subclass(torch.Tensor),
+                    bboxes.format,
+                    datapoints.BoundingBoxFormat.XYXY,
                )
                cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
                cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
@@ -1188,7 +1213,9 @@ class RandomIoUCrop(Transform):
        if len(params) < 1:
            return inpt
-        output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
+        output = self._call_kernel(
+            F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
+        )
        if isinstance(output, datapoints.BoundingBoxes):
            # We "mark" the invalid boxes as degenerate, and they can be
@@ -1262,7 +1289,9 @@ class ScaleJitter(Transform):
        return dict(size=(new_height, new_width))
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias)
+        return self._call_kernel(
+            F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
+        )
class RandomShortestSize(Transform):
@@ -1330,7 +1359,9 @@ class RandomShortestSize(Transform):
        return dict(size=(new_height, new_width))
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias)
+        return self._call_kernel(
+            F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
+        )
class RandomResize(Transform):
@@ -1400,4 +1431,6 @@ class RandomResize(Transform):
        return dict(size=[size])
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.resize(inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias)
+        return self._call_kernel(
+            F.resize, inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias
+        )
@@ -106,7 +106,7 @@ class LinearTransformation(Transform):
    def _check_inputs(self, sample: Any) -> Any:
        if has_any(sample, PIL.Image.Image):
-            raise TypeError("LinearTransformation does not work on PIL Images")
+            raise TypeError(f"{type(self).__name__}() does not support PIL images.")
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        shape = inpt.shape
@@ -157,7 +157,6 @@ class Normalize(Transform):
    """
    _v1_transform_cls = _transforms.Normalize
-    _transformed_types = (datapoints.Image, is_simple_tensor, datapoints.Video)
    def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
        super().__init__()
@@ -170,7 +169,7 @@ class Normalize(Transform):
            raise TypeError(f"{type(self).__name__}() does not support PIL images.")
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
+        return self._call_kernel(F.normalize, inpt, mean=self.mean, std=self.std, inplace=self.inplace)
class GaussianBlur(Transform):
@@ -217,7 +216,7 @@ class GaussianBlur(Transform):
        return dict(sigma=[sigma, sigma])
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.gaussian_blur(inpt, self.kernel_size, **params)
+        return self._call_kernel(F.gaussian_blur, inpt, self.kernel_size, **params)
class ToDtype(Transform):
@@ -290,7 +289,7 @@ class ToDtype(Transform):
            )
            return inpt
-        return F.to_dtype(inpt, dtype=dtype, scale=self.scale)
+        return self._call_kernel(F.to_dtype, inpt, dtype=dtype, scale=self.scale)
class ConvertImageDtype(Transform):
@@ -320,14 +319,12 @@ class ConvertImageDtype(Transform):
    _v1_transform_cls = _transforms.ConvertImageDtype
-    _transformed_types = (is_simple_tensor, datapoints.Image)
    def __init__(self, dtype: torch.dtype = torch.float32) -> None:
        super().__init__()
        self.dtype = dtype
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return F.to_dtype(inpt, dtype=self.dtype, scale=True)
+        return self._call_kernel(F.to_dtype, inpt, dtype=self.dtype, scale=True)
class SanitizeBoundingBoxes(Transform):
...
...@@ -25,4 +25,4 @@ class UniformTemporalSubsample(Transform): ...@@ -25,4 +25,4 @@ class UniformTemporalSubsample(Transform):
self.num_samples = num_samples self.num_samples = num_samples
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.uniform_temporal_subsample(inpt, self.num_samples) return self._call_kernel(F.uniform_temporal_subsample, inpt, self.num_samples)
...@@ -11,6 +11,8 @@ from torchvision import datapoints ...@@ -11,6 +11,8 @@ from torchvision import datapoints
from torchvision.transforms.v2.utils import check_type, has_any, is_simple_tensor from torchvision.transforms.v2.utils import check_type, has_any, is_simple_tensor
from torchvision.utils import _log_api_usage_once from torchvision.utils import _log_api_usage_once
from .functional._utils import _get_kernel
class Transform(nn.Module): class Transform(nn.Module):
...@@ -28,6 +30,10 @@ class Transform(nn.Module): ...@@ -28,6 +30,10 @@ class Transform(nn.Module):
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict() return dict()
def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
kernel = _get_kernel(dispatcher, type(inpt), allow_passthrough=True)
return kernel(inpt, *args, **kwargs)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
raise NotImplementedError raise NotImplementedError
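This _call_kernel hook is where the passthrough now lives: it queries _get_kernel with allow_passthrough=True, so input types without a registered kernel come back unchanged instead of raising. A rough sketch of what that means for a hand-written transform (the Blur subclass below is illustrative only, not part of this PR):

    import torch
    from torchvision import datapoints
    from torchvision.transforms import v2 as transforms
    from torchvision.transforms.v2 import functional as F

    class Blur(transforms.Transform):  # illustrative subclass
        def _transform(self, inpt, params):
            return self._call_kernel(F.gaussian_blur, inpt, kernel_size=[3, 3])

    img = datapoints.Image(torch.rand(3, 10, 10))
    boxes = datapoints.BoundingBoxes(torch.zeros(1, 4), format="XYXY", canvas_size=(10, 10))

    out_img, out_boxes = Blur()(img, boxes)
    # gaussian_blur no longer registers a no-op for BoundingBoxes;
    # _call_kernel passes them through instead
    assert out_boxes is boxes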
...
@@ -5,10 +5,9 @@ from torchvision import datapoints
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from torchvision.utils import _log_api_usage_once

-from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
+from ._utils import _get_kernel, _register_kernel_internal

-@_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True)
 def erase(
     inpt: torch.Tensor,
     i: int,
...
@@ -10,12 +10,10 @@ from torchvision.transforms._functional_tensor import _max_value
 from torchvision.utils import _log_api_usage_once

 from ._misc import _num_value_bits, to_dtype_image_tensor
 from ._type_conversion import pil_to_tensor, to_image_pil
-from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
+from ._utils import _get_kernel, _register_kernel_internal

-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, datapoints.Video)
 def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
     if torch.jit.is_scripting():
         return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels)
@@ -70,8 +68,8 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
     return output if fp else output.to(image1.dtype)

-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
@@ -107,7 +105,6 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor:
     return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor)

-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
@@ -146,7 +143,6 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor:
     return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor)

-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
@@ -185,7 +181,6 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor:
     return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor)

-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
@@ -258,7 +253,6 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
     return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor)

-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_hue(inpt: torch.Tensor, hue_factor: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
@@ -370,7 +364,6 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
     return adjust_hue_image_tensor(video, hue_factor=hue_factor)

-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
     if torch.jit.is_scripting():
         return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
@@ -410,7 +403,6 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
     return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain)

-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:
     if torch.jit.is_scripting():
         return posterize_image_tensor(inpt, bits=bits)
@@ -444,7 +436,6 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
     return posterize_image_tensor(video, bits=bits)

-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor:
     if torch.jit.is_scripting():
         return solarize_image_tensor(inpt, threshold=threshold)
@@ -472,7 +463,6 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
     return solarize_image_tensor(video, threshold=threshold)

-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def autocontrast(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
         return autocontrast_image_tensor(inpt)
@@ -522,7 +512,6 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
     return autocontrast_image_tensor(video)

-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def equalize(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
         return equalize_image_tensor(inpt)
@@ -612,7 +601,6 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
     return equalize_image_tensor(video)

-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def invert(inpt: torch.Tensor) -> torch.Tensor:
     if torch.jit.is_scripting():
         return invert_image_tensor(inpt)
@@ -643,7 +631,6 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
     return invert_image_tensor(video)

-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor:
     """Permute the channels of the input according to the given permutation.
...
@@ -25,13 +25,7 @@ from torchvision.utils import _log_api_usage_once
 from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil
-from ._utils import (
-    _FillTypeJIT,
-    _get_kernel,
-    _register_explicit_noop,
-    _register_five_ten_crop_kernel,
-    _register_kernel_internal,
-)
+from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal

 def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
@@ -2203,7 +2197,6 @@ def resized_crop_video(
     )

-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True)
 def five_crop(
     inpt: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -2230,8 +2223,8 @@ def _parse_five_crop_size(size: List[int]) -> List[int]:
     return size

-@_register_five_ten_crop_kernel(five_crop, torch.Tensor)
-@_register_five_ten_crop_kernel(five_crop, datapoints.Image)
+@_register_five_ten_crop_kernel_internal(five_crop, torch.Tensor)
+@_register_five_ten_crop_kernel_internal(five_crop, datapoints.Image)
 def five_crop_image_tensor(
     image: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -2250,7 +2243,7 @@ def five_crop_image_tensor(
     return tl, tr, bl, br, center

-@_register_five_ten_crop_kernel(five_crop, PIL.Image.Image)
+@_register_five_ten_crop_kernel_internal(five_crop, PIL.Image.Image)
 def five_crop_image_pil(
     image: PIL.Image.Image, size: List[int]
 ) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
@@ -2269,14 +2262,13 @@ def five_crop_image_pil(
     return tl, tr, bl, br, center

-@_register_five_ten_crop_kernel(five_crop, datapoints.Video)
+@_register_five_ten_crop_kernel_internal(five_crop, datapoints.Video)
 def five_crop_video(
     video: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     return five_crop_image_tensor(video, size)

-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True)
 def ten_crop(
     inpt: torch.Tensor, size: List[int], vertical_flip: bool = False
 ) -> Tuple[
@@ -2300,8 +2292,8 @@ def ten_crop(
     return kernel(inpt, size=size, vertical_flip=vertical_flip)

-@_register_five_ten_crop_kernel(ten_crop, torch.Tensor)
-@_register_five_ten_crop_kernel(ten_crop, datapoints.Image)
+@_register_five_ten_crop_kernel_internal(ten_crop, torch.Tensor)
+@_register_five_ten_crop_kernel_internal(ten_crop, datapoints.Image)
 def ten_crop_image_tensor(
     image: torch.Tensor, size: List[int], vertical_flip: bool = False
 ) -> Tuple[
@@ -2328,7 +2320,7 @@ def ten_crop_image_tensor(
     return non_flipped + flipped

-@_register_five_ten_crop_kernel(ten_crop, PIL.Image.Image)
+@_register_five_ten_crop_kernel_internal(ten_crop, PIL.Image.Image)
 def ten_crop_image_pil(
     image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
 ) -> Tuple[
@@ -2355,7 +2347,7 @@ def ten_crop_image_pil(
     return non_flipped + flipped

-@_register_five_ten_crop_kernel(ten_crop, datapoints.Video)
+@_register_five_ten_crop_kernel_internal(ten_crop, datapoints.Video)
 def ten_crop_video(
     video: torch.Tensor, size: List[int], vertical_flip: bool = False
 ) -> Tuple[
...
@@ -8,10 +8,9 @@ from torchvision.transforms import _functional_pil as _FP
 from torchvision.utils import _log_api_usage_once

-from ._utils import _get_kernel, _register_kernel_internal, _register_unsupported_type, is_simple_tensor
+from ._utils import _get_kernel, _register_kernel_internal, is_simple_tensor

-@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
 def get_dimensions(inpt: torch.Tensor) -> List[int]:
     if torch.jit.is_scripting():
         return get_dimensions_image_tensor(inpt)
@@ -44,7 +43,6 @@ def get_dimensions_video(video: torch.Tensor) -> List[int]:
     return get_dimensions_image_tensor(video)

-@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
 def get_num_channels(inpt: torch.Tensor) -> int:
     if torch.jit.is_scripting():
         return get_num_channels_image_tensor(inpt)
@@ -123,7 +121,6 @@ def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]:
     return list(bounding_box.canvas_size)

-@_register_unsupported_type(PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)
 def get_num_frames(inpt: torch.Tensor) -> int:
     if torch.jit.is_scripting():
         return get_num_frames_video(inpt)
...
@@ -11,11 +11,9 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from torchvision.utils import _log_api_usage_once

-from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, _register_unsupported_type
+from ._utils import _get_kernel, _register_kernel_internal

-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
-@_register_unsupported_type(PIL.Image.Image)
 def normalize(
     inpt: torch.Tensor,
     mean: List[float],
@@ -73,7 +71,6 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
     return normalize_image_tensor(video, mean, std, inplace=inplace)

-@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
 def gaussian_blur(inpt: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> torch.Tensor:
     if torch.jit.is_scripting():
         return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
@@ -182,7 +179,6 @@ def gaussian_blur_video(
     return gaussian_blur_image_tensor(video, kernel_size, sigma)

-@_register_unsupported_type(PIL.Image.Image)
 def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
     if torch.jit.is_scripting():
         return to_dtype_image_tensor(inpt, dtype=dtype, scale=scale)
...
-import PIL.Image
 import torch

 from torchvision import datapoints
 from torchvision.utils import _log_api_usage_once

-from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
+from ._utils import _get_kernel, _register_kernel_internal

-@_register_explicit_noop(
-    PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True
-)
 def uniform_temporal_subsample(inpt: torch.Tensor, num_samples: int) -> torch.Tensor:
     if torch.jit.is_scripting():
         return uniform_temporal_subsample_video(inpt, num_samples=num_samples)
...
 import functools
-import warnings
 from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

 import torch
@@ -53,6 +52,11 @@ def _name_to_dispatcher(name):
     ) from None

+_BUILTIN_DATAPOINT_TYPES = {
+    obj for obj in datapoints.__dict__.values() if isinstance(obj, type) and issubclass(obj, datapoints.Datapoint)
+}
+

 def register_kernel(dispatcher, datapoint_cls):
     """Decorate a kernel to register it for a dispatcher and a (custom) datapoint type.
@@ -70,20 +74,19 @@ def register_kernel(dispatcher, datapoint_cls):
             f"but got {dispatcher}."
         )

-    if not (
-        isinstance(datapoint_cls, type)
-        and issubclass(datapoint_cls, datapoints.Datapoint)
-        and datapoint_cls is not datapoints.Datapoint
-    ):
+    if not (isinstance(datapoint_cls, type) and issubclass(datapoint_cls, datapoints.Datapoint)):
         raise ValueError(
             f"Kernels can only be registered for subclasses of torchvision.datapoints.Datapoint, "
             f"but got {datapoint_cls}."
         )

+    if datapoint_cls in _BUILTIN_DATAPOINT_TYPES:
+        raise ValueError(f"Kernels cannot be registered for the builtin datapoint classes, but got {datapoint_cls}")
+
     return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)
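The new _BUILTIN_DATAPOINT_TYPES guard takes over the role the explicit no-op registrations used to play: custom datapoint classes can still receive kernels, while the builtin classes are now rejected up front. A sketch, assuming a hypothetical custom datapoint (MyDatapoint and resize_my_datapoint are illustrative names, not part of this PR):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    class MyDatapoint(datapoints.Datapoint):  # hypothetical custom datapoint
        pass

    @F.register_kernel(F.resize, MyDatapoint)
    def resize_my_datapoint(inpt, size, **kwargs):
        ...  # custom resize logic would go here

    try:
        # builtin datapoint classes can no longer be targeted
        F.register_kernel(F.resize, datapoints.Image)
    except ValueError as e:
        print(e)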
-def _get_kernel(dispatcher, input_type):
+def _get_kernel(dispatcher, input_type, *, allow_passthrough=False):
     registry = _KERNEL_REGISTRY.get(dispatcher)
     if not registry:
         raise ValueError(f"No kernel registered for dispatcher {dispatcher.__name__}.")
@@ -104,78 +107,18 @@ def _get_kernel(dispatcher, input_type):
         elif cls in registry:
             return registry[cls]

-    # Note that in the future we are not going to return a noop here, but rather raise the error below
-    return _noop
+    if allow_passthrough:
+        return lambda inpt, *args, **kwargs: inpt

     raise TypeError(
-        f"Dispatcher {dispatcher} supports inputs of type torch.Tensor, PIL.Image.Image, "
-        f"and subclasses of torchvision.datapoints.Datapoint, "
+        f"Dispatcher F.{dispatcher.__name__} supports inputs of type {registry.keys()}, "
         f"but got {input_type} instead."
     )
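Outside of Transform._call_kernel, allow_passthrough stays False, so calling a dispatcher directly with a type that has no registered kernel now raises instead of silently returning the input. Roughly, under this diff's registry state:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    boxes = datapoints.BoundingBoxes(torch.zeros(1, 4), format="XYXY", canvas_size=(10, 10))
    try:
        F.gaussian_blur(boxes, kernel_size=[3, 3])
    except TypeError as e:
        # e.g. "Dispatcher F.gaussian_blur supports inputs of type ..., but got ... instead."
        print(e)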
-# Everything below this block is stuff that we need right now, since it looks like we need to release in an intermediate
-# stage. See https://github.com/pytorch/vision/pull/7747#issuecomment-1661698450 for details.
-# In the future, the default behavior will be to error on unsupported types in dispatchers. The noop behavior that we
-# need for transforms will be handled by _get_kernel rather than actually registering no-ops on the dispatcher.
-# Finally, the use case of preventing users from registering kernels for our builtin types will be handled inside
-# register_kernel.

-def _register_explicit_noop(*datapoints_classes, warn_passthrough=False):
-    """
-    Although this looks redundant with the no-op behavior of _get_kernel, this explicit registration prevents users
-    from registering kernels for builtin datapoints on builtin dispatchers that rely on the no-op behavior.
-
-    For example, without explicit no-op registration the following would be valid user code:
-
-    .. code::
-
-        from torchvision.transforms.v2 import functional as F
-
-        @F.register_kernel(F.adjust_brightness, datapoints.BoundingBox)
-        def lol(...):
-            ...
-    """
-
-    def decorator(dispatcher):
-        for cls in datapoints_classes:
-            msg = (
-                f"F.{dispatcher.__name__} is currently passing through inputs of type datapoints.{cls.__name__}. "
-                f"This will likely change in the future."
-            )
-            _register_kernel_internal(dispatcher, cls, datapoint_wrapper=False)(
-                functools.partial(_noop, __msg__=msg if warn_passthrough else None)
-            )
-        return dispatcher
-
-    return decorator

-def _noop(inpt, *args, __msg__=None, **kwargs):
-    if __msg__:
-        warnings.warn(__msg__, UserWarning, stacklevel=2)
-    return inpt

-# TODO: we only need this, since our default behavior in case no kernel is found is passthrough. When we change that
-# to error later, this decorator can be removed, since the error will be raised by _get_kernel
-def _register_unsupported_type(*input_types):
-    def kernel(inpt, *args, __dispatcher_name__, **kwargs):
-        raise TypeError(f"F.{__dispatcher_name__} does not support inputs of type {type(inpt)}.")
-
-    def decorator(dispatcher):
-        for input_type in input_types:
-            _register_kernel_internal(dispatcher, input_type, datapoint_wrapper=False)(
-                functools.partial(kernel, __dispatcher_name__=dispatcher.__name__)
-            )
-        return dispatcher
-
-    return decorator

 # This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop
 # We could get rid of this by letting _register_kernel_internal take arbitrary dispatchers rather than wrap_kernel: bool
-def _register_five_ten_crop_kernel(dispatcher, input_type):
+def _register_five_ten_crop_kernel_internal(dispatcher, input_type):
     registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
     if input_type in registry:
         raise TypeError(f"Dispatcher '{dispatcher}' already has a kernel registered for type '{input_type}'.")
...