Unverified Commit db310636 authored by Philip Meier, committed by GitHub

move passthrough for unknown types from dispatchers to transforms (#7804)

parent 87681314
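
The gist of the change, as a minimal sketch (simplified; the names _KERNEL_REGISTRY, _get_kernel, and Transform._call_kernel mirror the diff below, but this is not the verbatim torchvision implementation): dispatchers now raise for unregistered input types, while transforms opt into a passthrough by routing kernel calls through self._call_kernel.

# Minimal sketch of the new dispatch flow, simplified from the diff below.
from typing import Any, Callable, Dict

_KERNEL_REGISTRY: Dict[Callable, Dict[type, Callable]] = {}

def _get_kernel(dispatcher: Callable, input_type: type, *, allow_passthrough: bool = False) -> Callable:
    registry = _KERNEL_REGISTRY.get(dispatcher, {})
    for cls in input_type.__mro__:  # walk the MRO so subclasses of registered types still match
        if cls in registry:
            return registry[cls]
    if allow_passthrough:
        # Only transforms opt into this; the no-op no longer lives on the dispatcher itself.
        return lambda inpt, *args, **kwargs: inpt
    raise TypeError(f"F.{dispatcher.__name__} supports inputs of type {set(registry)}, but got {input_type} instead.")

class Transform:
    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
        kernel = _get_kernel(dispatcher, type(inpt), allow_passthrough=True)
        return kernel(inpt, *args, **kwargs)

The hunks below update the tests, the transform classes, and the functional dispatchers accordingly.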
import itertools
import re
import PIL.Image
......@@ -19,7 +17,6 @@ from prototype_common_utils import make_label
from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
from torchvision.prototype import datapoints, transforms
from torchvision.transforms.v2._utils import _convert_fill_arg
from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_image_pil
from torchvision.transforms.v2.utils import check_type, is_simple_tensor
......@@ -187,66 +184,6 @@ class TestFixedSizeCrop:
assert params["needs_pad"]
assert any(pad > 0 for pad in params["padding"])
@pytest.mark.parametrize("needs", list(itertools.product((False, True), repeat=2)))
def test__transform(self, mocker, needs):
fill_sentinel = 12
padding_mode_sentinel = mocker.MagicMock()
transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel)
transform._transformed_types = (mocker.MagicMock,)
mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
needs_crop, needs_pad = needs
top_sentinel = mocker.MagicMock()
left_sentinel = mocker.MagicMock()
height_sentinel = mocker.MagicMock()
width_sentinel = mocker.MagicMock()
is_valid = mocker.MagicMock() if needs_crop else None
padding_sentinel = mocker.MagicMock()
mocker.patch(
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
return_value=dict(
needs_crop=needs_crop,
top=top_sentinel,
left=left_sentinel,
height=height_sentinel,
width=width_sentinel,
is_valid=is_valid,
padding=padding_sentinel,
needs_pad=needs_pad,
),
)
inpt_sentinel = mocker.MagicMock()
mock_crop = mocker.patch("torchvision.prototype.transforms._geometry.F.crop")
mock_pad = mocker.patch("torchvision.prototype.transforms._geometry.F.pad")
transform(inpt_sentinel)
if needs_crop:
mock_crop.assert_called_once_with(
inpt_sentinel,
top=top_sentinel,
left=left_sentinel,
height=height_sentinel,
width=width_sentinel,
)
else:
mock_crop.assert_not_called()
if needs_pad:
# If we cropped before, the input to F.pad is no longer inpt_sentinel. Thus, we can't use
# `MagicMock.assert_called_once_with` and have to perform the checks manually
mock_pad.assert_called_once()
args, kwargs = mock_pad.call_args
if not needs_crop:
assert args[0] is inpt_sentinel
assert args[1] is padding_sentinel
fill_sentinel = _convert_fill_arg(fill_sentinel)
assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel)
else:
mock_pad.assert_not_called()
def test__transform_culling(self, mocker):
batch_size = 10
canvas_size = (10, 10)
......
......@@ -27,7 +27,7 @@ from common_utils import (
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints
from torchvision.ops.boxes import box_iou
from torchvision.transforms.functional import InterpolationMode, to_pil_image
from torchvision.transforms.functional import to_pil_image
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.utils import check_type, is_simple_tensor, query_chw
......@@ -419,46 +419,6 @@ class TestPad:
with pytest.raises(ValueError, match="Padding mode should be either"):
transforms.Pad(12, padding_mode="abc")
@pytest.mark.parametrize("padding", [1, (1, 2), [1, 2, 3, 4]])
@pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
@pytest.mark.parametrize("padding_mode", ["constant", "edge"])
def test__transform(self, padding, fill, padding_mode, mocker):
transform = transforms.Pad(padding, fill=fill, padding_mode=padding_mode)
fn = mocker.patch("torchvision.transforms.v2.functional.pad")
inpt = mocker.MagicMock(spec=datapoints.Image)
_ = transform(inpt)
fill = transforms._utils._convert_fill_arg(fill)
if isinstance(padding, tuple):
padding = list(padding)
fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
@pytest.mark.parametrize("fill", [12, {datapoints.Image: 12, datapoints.Mask: 34}])
def test__transform_image_mask(self, fill, mocker):
transform = transforms.Pad(1, fill=fill, padding_mode="constant")
fn = mocker.patch("torchvision.transforms.v2.functional.pad")
image = datapoints.Image(torch.rand(3, 32, 32))
mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32)))
inpt = [image, mask]
_ = transform(inpt)
if isinstance(fill, int):
fill = transforms._utils._convert_fill_arg(fill)
calls = [
mocker.call(image, padding=1, fill=fill, padding_mode="constant"),
mocker.call(mask, padding=1, fill=fill, padding_mode="constant"),
]
else:
fill_img = transforms._utils._convert_fill_arg(fill[type(image)])
fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)])
calls = [
mocker.call(image, padding=1, fill=fill_img, padding_mode="constant"),
mocker.call(mask, padding=1, fill=fill_mask, padding_mode="constant"),
]
fn.assert_has_calls(calls)
class TestRandomZoomOut:
def test_assertions(self):
......@@ -487,56 +447,6 @@ class TestRandomZoomOut:
assert 0 <= params["padding"][2] <= (side_range[1] - 1) * w
assert 0 <= params["padding"][3] <= (side_range[1] - 1) * h
@pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
@pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]])
def test__transform(self, fill, side_range, mocker):
inpt = make_image((24, 32))
transform = transforms.RandomZoomOut(fill=fill, side_range=side_range, p=1)
fn = mocker.patch("torchvision.transforms.v2.functional.pad")
# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
# Otherwise, we can mock transform._get_params
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
torch.rand(1) # random apply changes random state
params = transform._get_params([inpt])
fill = transforms._utils._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill)
@pytest.mark.parametrize("fill", [12, {datapoints.Image: 12, datapoints.Mask: 34}])
def test__transform_image_mask(self, fill, mocker):
transform = transforms.RandomZoomOut(fill=fill, p=1.0)
fn = mocker.patch("torchvision.transforms.v2.functional.pad")
image = datapoints.Image(torch.rand(3, 32, 32))
mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32)))
inpt = [image, mask]
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
torch.rand(1) # random apply changes random state
params = transform._get_params(inpt)
if isinstance(fill, int):
fill = transforms._utils._convert_fill_arg(fill)
calls = [
mocker.call(image, **params, fill=fill),
mocker.call(mask, **params, fill=fill),
]
else:
fill_img = transforms._utils._convert_fill_arg(fill[type(image)])
fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)])
calls = [
mocker.call(image, **params, fill=fill_img),
mocker.call(mask, **params, fill=fill_mask),
]
fn.assert_has_calls(calls)
class TestRandomCrop:
def test_assertions(self):
......@@ -599,51 +509,6 @@ class TestRandomCrop:
assert params["needs_pad"] is any(padding)
assert params["padding"] == padding
@pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]])
@pytest.mark.parametrize("pad_if_needed", [False, True])
@pytest.mark.parametrize("fill", [False, True])
@pytest.mark.parametrize("padding_mode", ["constant", "edge"])
def test__transform(self, padding, pad_if_needed, fill, padding_mode, mocker):
output_size = [10, 12]
transform = transforms.RandomCrop(
output_size, padding=padding, pad_if_needed=pad_if_needed, fill=fill, padding_mode=padding_mode
)
h, w = size = (32, 32)
inpt = make_image(size)
if isinstance(padding, int):
new_size = (h + padding, w + padding)
elif isinstance(padding, list):
new_size = (h + sum(padding[0::2]), w + sum(padding[1::2]))
else:
new_size = size
expected = make_image(new_size)
_ = mocker.patch("torchvision.transforms.v2.functional.pad", return_value=expected)
fn_crop = mocker.patch("torchvision.transforms.v2.functional.crop")
# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
# Otherwise, we can mock transform._get_params
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
params = transform._get_params([inpt])
if padding is None and not pad_if_needed:
fn_crop.assert_called_once_with(
inpt, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1]
)
elif not pad_if_needed:
fn_crop.assert_called_once_with(
expected, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1]
)
elif padding is None:
# vfdev-5: I do not know how to mock and test this case
pass
else:
# vfdev-5: I do not know how to mock and test this case
pass
class TestGaussianBlur:
def test_assertions(self):
......@@ -675,62 +540,6 @@ class TestGaussianBlur:
assert sigma[0] <= params["sigma"][0] <= sigma[1]
assert sigma[0] <= params["sigma"][1] <= sigma[1]
@pytest.mark.parametrize("kernel_size", [3, [3, 5], (5, 3)])
@pytest.mark.parametrize("sigma", [2.0, [2.0, 3.0]])
def test__transform(self, kernel_size, sigma, mocker):
transform = transforms.GaussianBlur(kernel_size=kernel_size, sigma=sigma)
if isinstance(kernel_size, (tuple, list)):
assert transform.kernel_size == kernel_size
else:
kernel_size = (kernel_size, kernel_size)
assert transform.kernel_size == kernel_size
if isinstance(sigma, (tuple, list)):
assert transform.sigma == sigma
else:
assert transform.sigma == [sigma, sigma]
fn = mocker.patch("torchvision.transforms.v2.functional.gaussian_blur")
inpt = mocker.MagicMock(spec=datapoints.Image)
inpt.num_channels = 3
inpt.canvas_size = (24, 32)
# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
# Otherwise, we can mock transform._get_params
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
params = transform._get_params([inpt])
fn.assert_called_once_with(inpt, kernel_size, **params)
class TestRandomColorOp:
@pytest.mark.parametrize("p", [0.0, 1.0])
@pytest.mark.parametrize(
"transform_cls, func_op_name, kwargs",
[
(transforms.RandomEqualize, "equalize", {}),
(transforms.RandomInvert, "invert", {}),
(transforms.RandomAutocontrast, "autocontrast", {}),
(transforms.RandomPosterize, "posterize", {"bits": 4}),
(transforms.RandomSolarize, "solarize", {"threshold": 0.5}),
(transforms.RandomAdjustSharpness, "adjust_sharpness", {"sharpness_factor": 0.5}),
],
)
def test__transform(self, p, transform_cls, func_op_name, kwargs, mocker):
transform = transform_cls(p=p, **kwargs)
fn = mocker.patch(f"torchvision.transforms.v2.functional.{func_op_name}")
inpt = mocker.MagicMock(spec=datapoints.Image)
_ = transform(inpt)
if p > 0.0:
fn.assert_called_once_with(inpt, **kwargs)
else:
assert fn.call_count == 0
class TestRandomPerspective:
def test_assertions(self):
......@@ -751,28 +560,6 @@ class TestRandomPerspective:
assert "coefficients" in params
assert len(params["coefficients"]) == 8
@pytest.mark.parametrize("distortion_scale", [0.1, 0.7])
def test__transform(self, distortion_scale, mocker):
interpolation = InterpolationMode.BILINEAR
fill = 12
transform = transforms.RandomPerspective(distortion_scale, fill=fill, interpolation=interpolation)
fn = mocker.patch("torchvision.transforms.v2.functional.perspective")
inpt = make_image((24, 32))
# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
# Otherwise, we can mock transform._get_params
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
torch.rand(1) # random apply changes random state
params = transform._get_params([inpt])
fill = transforms._utils._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, None, None, **params, fill=fill, interpolation=interpolation)
class TestElasticTransform:
def test_assertions(self):
......@@ -813,35 +600,6 @@ class TestElasticTransform:
assert (-alpha / w <= displacement[0, ..., 0]).all() and (displacement[0, ..., 0] <= alpha / w).all()
assert (-alpha / h <= displacement[0, ..., 1]).all() and (displacement[0, ..., 1] <= alpha / h).all()
@pytest.mark.parametrize("alpha", [5.0, [5.0, 10.0]])
@pytest.mark.parametrize("sigma", [2.0, [2.0, 5.0]])
def test__transform(self, alpha, sigma, mocker):
interpolation = InterpolationMode.BILINEAR
fill = 12
transform = transforms.ElasticTransform(alpha, sigma=sigma, fill=fill, interpolation=interpolation)
if isinstance(alpha, float):
assert transform.alpha == [alpha, alpha]
else:
assert transform.alpha == alpha
if isinstance(sigma, float):
assert transform.sigma == [sigma, sigma]
else:
assert transform.sigma == sigma
fn = mocker.patch("torchvision.transforms.v2.functional.elastic")
inpt = mocker.MagicMock(spec=datapoints.Image)
inpt.num_channels = 3
inpt.canvas_size = (24, 32)
# Let's mock transform._get_params to control the output:
transform._get_params = mocker.MagicMock()
_ = transform(inpt)
params = transform._get_params([inpt])
fill = transforms._utils._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
class TestRandomErasing:
def test_assertions(self):
......@@ -889,40 +647,6 @@ class TestRandomErasing:
assert 0 <= i <= height - h
assert 0 <= j <= width - w
@pytest.mark.parametrize("p", [0, 1])
def test__transform(self, mocker, p):
transform = transforms.RandomErasing(p=p)
transform._transformed_types = (mocker.MagicMock,)
i_sentinel = mocker.MagicMock()
j_sentinel = mocker.MagicMock()
h_sentinel = mocker.MagicMock()
w_sentinel = mocker.MagicMock()
v_sentinel = mocker.MagicMock()
mocker.patch(
"torchvision.transforms.v2._augment.RandomErasing._get_params",
return_value=dict(i=i_sentinel, j=j_sentinel, h=h_sentinel, w=w_sentinel, v=v_sentinel),
)
inpt_sentinel = mocker.MagicMock()
mock = mocker.patch("torchvision.transforms.v2._augment.F.erase")
output = transform(inpt_sentinel)
if p:
mock.assert_called_once_with(
inpt_sentinel,
i=i_sentinel,
j=j_sentinel,
h=h_sentinel,
w=w_sentinel,
v=v_sentinel,
inplace=transform.inplace,
)
else:
mock.assert_not_called()
assert output is inpt_sentinel
class TestTransform:
@pytest.mark.parametrize(
......@@ -1111,23 +835,12 @@ class TestRandomIoUCrop:
sample = [image, bboxes, masks]
fn = mocker.patch("torchvision.transforms.v2.functional.crop", side_effect=lambda x, **params: x)
is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool)
params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area)
transform._get_params = mocker.MagicMock(return_value=params)
output = transform(sample)
assert fn.call_count == 3
expected_calls = [
mocker.call(image, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
mocker.call(bboxes, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
mocker.call(masks, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
]
fn.assert_has_calls(expected_calls)
# check number of bboxes vs number of labels:
output_bboxes = output[1]
assert isinstance(output_bboxes, datapoints.BoundingBoxes)
......@@ -1164,29 +877,6 @@ class TestScaleJitter:
assert int(canvas_size[0] * r_min) <= height <= int(canvas_size[0] * r_max)
assert int(canvas_size[1] * r_min) <= width <= int(canvas_size[1] * r_max)
def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
antialias_sentinel = mocker.MagicMock()
transform = transforms.ScaleJitter(
target_size=(16, 12), interpolation=interpolation_sentinel, antialias=antialias_sentinel
)
transform._transformed_types = (mocker.MagicMock,)
size_sentinel = mocker.MagicMock()
mocker.patch(
"torchvision.transforms.v2._geometry.ScaleJitter._get_params", return_value=dict(size=size_sentinel)
)
inpt_sentinel = mocker.MagicMock()
mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
transform(inpt_sentinel)
mock.assert_called_once_with(
inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel
)
class TestRandomShortestSize:
@pytest.mark.parametrize("min_size,max_size", [([5, 9], 20), ([5, 9], None)])
......@@ -1211,30 +901,6 @@ class TestRandomShortestSize:
else:
assert shorter in min_size
def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
antialias_sentinel = mocker.MagicMock()
transform = transforms.RandomShortestSize(
min_size=[3, 5, 7], max_size=12, interpolation=interpolation_sentinel, antialias=antialias_sentinel
)
transform._transformed_types = (mocker.MagicMock,)
size_sentinel = mocker.MagicMock()
mocker.patch(
"torchvision.transforms.v2._geometry.RandomShortestSize._get_params",
return_value=dict(size=size_sentinel),
)
inpt_sentinel = mocker.MagicMock()
mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
transform(inpt_sentinel)
mock.assert_called_once_with(
inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel
)
class TestLinearTransformation:
def test_assertions(self):
......@@ -1260,7 +926,7 @@ class TestLinearTransformation:
transform = transforms.LinearTransformation(m, v)
if isinstance(inpt, PIL.Image.Image):
with pytest.raises(TypeError, match="LinearTransformation does not work on PIL Images"):
with pytest.raises(TypeError, match="does not support PIL images"):
transform(inpt)
else:
output = transform(inpt)
......@@ -1284,30 +950,6 @@ class TestRandomResize:
assert min_size <= size < max_size
def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
antialias_sentinel = mocker.MagicMock()
transform = transforms.RandomResize(
min_size=-1, max_size=-1, interpolation=interpolation_sentinel, antialias=antialias_sentinel
)
transform._transformed_types = (mocker.MagicMock,)
size_sentinel = mocker.MagicMock()
mocker.patch(
"torchvision.transforms.v2._geometry.RandomResize._get_params",
return_value=dict(size=size_sentinel),
)
inpt_sentinel = mocker.MagicMock()
mock_resize = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
transform(inpt_sentinel)
mock_resize.assert_called_with(
inpt_sentinel, size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel
)
class TestUniformTemporalSubsample:
@pytest.mark.parametrize(
......
......@@ -1259,68 +1259,6 @@ class TestRefSegTransforms:
def test_common(self, t_ref, t, data_kwargs):
self.check(t, t_ref, data_kwargs)
def check_resize(self, mocker, t_ref, t):
mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize")
mock_ref = mocker.patch("torchvision.transforms.functional.resize")
for dp, dp_ref in self.make_datapoints():
mock.reset_mock()
mock_ref.reset_mock()
self.set_seed()
t(dp)
assert mock.call_count == 2
assert all(
actual is expected
for actual, expected in zip([call_args[0][0] for call_args in mock.call_args_list], dp)
)
self.set_seed()
t_ref(*dp_ref)
assert mock_ref.call_count == 2
assert all(
actual is expected
for actual, expected in zip([call_args[0][0] for call_args in mock_ref.call_args_list], dp_ref)
)
for args_kwargs, args_kwargs_ref in zip(mock.call_args_list, mock_ref.call_args_list):
assert args_kwargs[0][1] == [args_kwargs_ref[0][1]]
def test_random_resize_train(self, mocker):
base_size = 520
min_size = base_size // 2
max_size = base_size * 2
randint = torch.randint
def patched_randint(a, b, *other_args, **kwargs):
if kwargs or len(other_args) > 1 or other_args[0] != ():
return randint(a, b, *other_args, **kwargs)
return random.randint(a, b)
# We are patching torch.randint -> random.randint here, because we can't patch the modules that are not imported
# normally
t = v2_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True)
mocker.patch(
"torchvision.transforms.v2._geometry.torch.randint",
new=patched_randint,
)
t_ref = seg_transforms.RandomResize(min_size=min_size, max_size=max_size)
self.check_resize(mocker, t_ref, t)
def test_random_resize_eval(self, mocker):
torch.manual_seed(0)
base_size = 520
t = v2_transforms.Resize(size=base_size, antialias=True)
t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size)
self.check_resize(mocker, t_ref, t)
@pytest.mark.parametrize(
("legacy_dispatcher", "name_only_params"),
......
......@@ -39,7 +39,7 @@ from torchvision import datapoints
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.functional import pil_modes_mapping
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._utils import _get_kernel, _KERNEL_REGISTRY, _noop, _register_kernel_internal
from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal
@pytest.fixture(autouse=True)
......@@ -376,35 +376,6 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
return torch.stack([transform(b) for b in bounding_boxes.reshape(-1, 4).unbind()]).reshape(bounding_boxes.shape)
@pytest.mark.parametrize(
("dispatcher", "registered_input_types"),
[(dispatcher, set(registry.keys())) for dispatcher, registry in _KERNEL_REGISTRY.items()],
)
def test_exhaustive_kernel_registration(dispatcher, registered_input_types):
missing = {
torch.Tensor,
PIL.Image.Image,
datapoints.Image,
datapoints.BoundingBoxes,
datapoints.Mask,
datapoints.Video,
} - registered_input_types
if missing:
names = sorted(str(t) for t in missing)
raise AssertionError(
"\n".join(
[
f"The dispatcher '{dispatcher.__name__}' has no kernel registered for",
"",
*[f"- {name}" for name in names],
"",
f"If available, register the kernels with @_register_kernel_internal({dispatcher.__name__}, ...).",
f"If not, register explicit no-ops with @_register_explicit_noop({', '.join(names)})",
]
)
)
class TestResize:
INPUT_SIZE = (17, 11)
OUTPUT_SIZES = [17, [17], (17,), [12, 13], (12, 13)]
......@@ -2128,9 +2099,20 @@ class TestRegisterKernel:
with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"):
F.register_kernel(F.resize, object)
with pytest.raises(ValueError, match="already has a kernel registered for type"):
with pytest.raises(ValueError, match="cannot be registered for the builtin datapoint classes"):
F.register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor)
class CustomDatapoint(datapoints.Datapoint):
pass
def resize_custom_datapoint():
pass
F.register_kernel(F.resize, CustomDatapoint)(resize_custom_datapoint)
with pytest.raises(ValueError, match="already has a kernel registered for type"):
F.register_kernel(F.resize, CustomDatapoint)(resize_custom_datapoint)
class TestGetKernel:
# We are using F.resize as dispatcher and the kernels below as proxy. Any other dispatcher / kernels combination
......@@ -2152,13 +2134,7 @@ class TestGetKernel:
pass
for input_type in [str, int, object, MyTensor, MyPILImage]:
with pytest.raises(
TypeError,
match=(
"supports inputs of type torch.Tensor, PIL.Image.Image, "
"and subclasses of torchvision.datapoints.Datapoint"
),
):
with pytest.raises(TypeError, match="supports inputs of type"):
_get_kernel(F.resize, input_type)
def test_exact_match(self):
......@@ -2211,8 +2187,8 @@ class TestGetKernel:
class MyDatapoint(datapoints.Datapoint):
pass
# Note that this will be an error in the future
assert _get_kernel(F.resize, MyDatapoint) is _noop
with pytest.raises(TypeError, match="supports inputs of type"):
_get_kernel(F.resize, MyDatapoint)
def resize_my_datapoint():
pass
......
......@@ -101,7 +101,8 @@ class FixedSizeCrop(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["needs_crop"]:
inpt = F.crop(
inpt = self._call_kernel(
F.crop,
inpt,
top=params["top"],
left=params["left"],
......@@ -120,6 +121,6 @@ class FixedSizeCrop(Transform):
if params["needs_pad"]:
fill = _get_fill(self._fill, type(inpt))
inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
inpt = self._call_kernel(F.pad, inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
return inpt
import math
import numbers
import warnings
from typing import Any, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Tuple
import PIL.Image
import torch
......@@ -91,6 +91,14 @@ class RandomErasing(_RandomApplyTransform):
self._log_ratio = torch.log(torch.tensor(self.ratio))
def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
warnings.warn(
f"{type(self).__name__}() is currently passing through inputs of type "
f"datapoints.{type(inpt).__name__}. This will likely change in the future."
)
return super()._call_kernel(dispatcher, inpt, *args, **kwargs)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
img_c, img_h, img_w = query_chw(flat_inputs)
......@@ -131,7 +139,7 @@ class RandomErasing(_RandomApplyTransform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["v"] is not None:
inpt = F.erase(inpt, **params, inplace=self.inplace)
inpt = self._call_kernel(F.erase, inpt, **params, inplace=self.inplace)
return inpt
......
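
For context, the observable effect of the RandomErasing._call_kernel override above, as an illustrative usage sketch assuming the datapoints / transforms.v2 API of this branch (not a definitive snippet):

# Illustrative only, assuming the datapoints and transforms.v2 API of this branch.
import torch
from torchvision import datapoints
from torchvision.transforms import v2

img = datapoints.Image(torch.rand(3, 32, 32))
mask = datapoints.Mask(torch.zeros(32, 32, dtype=torch.int64))

# The image gets a region erased; the mask triggers the warning above and is passed
# through unchanged, because the no-op now happens in the transform rather than in
# the F.erase dispatcher.
out_img, out_mask = v2.RandomErasing(p=1.0)(img, mask)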
import collections.abc
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import PIL.Image
import torch
from torchvision import datapoints, transforms as _transforms
from torchvision import transforms as _transforms
from torchvision.transforms.v2 import functional as F, Transform
from ._transform import _RandomApplyTransform
from .utils import is_simple_tensor, query_chw
from .utils import query_chw
class Grayscale(Transform):
......@@ -24,19 +23,12 @@ class Grayscale(Transform):
_v1_transform_cls = _transforms.Grayscale
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(self, num_output_channels: int = 1):
super().__init__()
self.num_output_channels = num_output_channels
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=self.num_output_channels)
class RandomGrayscale(_RandomApplyTransform):
......@@ -55,13 +47,6 @@ class RandomGrayscale(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomGrayscale
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(self, p: float = 0.1) -> None:
super().__init__(p=p)
......@@ -70,7 +55,7 @@ class RandomGrayscale(_RandomApplyTransform):
return dict(num_input_channels=num_input_channels)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"])
class ColorJitter(Transform):
......@@ -167,13 +152,13 @@ class ColorJitter(Transform):
hue_factor = params["hue_factor"]
for fn_id in params["fn_idx"]:
if fn_id == 0 and brightness_factor is not None:
output = F.adjust_brightness(output, brightness_factor=brightness_factor)
output = self._call_kernel(F.adjust_brightness, output, brightness_factor=brightness_factor)
elif fn_id == 1 and contrast_factor is not None:
output = F.adjust_contrast(output, contrast_factor=contrast_factor)
output = self._call_kernel(F.adjust_contrast, output, contrast_factor=contrast_factor)
elif fn_id == 2 and saturation_factor is not None:
output = F.adjust_saturation(output, saturation_factor=saturation_factor)
output = self._call_kernel(F.adjust_saturation, output, saturation_factor=saturation_factor)
elif fn_id == 3 and hue_factor is not None:
output = F.adjust_hue(output, hue_factor=hue_factor)
output = self._call_kernel(F.adjust_hue, output, hue_factor=hue_factor)
return output
......@@ -183,19 +168,12 @@ class RandomChannelPermutation(Transform):
.. v2betastatus:: RandomChannelPermutation transform
"""
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
num_channels, *_ = query_chw(flat_inputs)
return dict(permutation=torch.randperm(num_channels))
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.permute_channels(inpt, params["permutation"])
return self._call_kernel(F.permute_channels, inpt, params["permutation"])
class RandomPhotometricDistort(Transform):
......@@ -224,13 +202,6 @@ class RandomPhotometricDistort(Transform):
Default is 0.5.
"""
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(
self,
brightness: Tuple[float, float] = (0.875, 1.125),
......@@ -263,17 +234,17 @@ class RandomPhotometricDistort(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["brightness_factor"] is not None:
inpt = F.adjust_brightness(inpt, brightness_factor=params["brightness_factor"])
inpt = self._call_kernel(F.adjust_brightness, inpt, brightness_factor=params["brightness_factor"])
if params["contrast_factor"] is not None and params["contrast_before"]:
inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"])
inpt = self._call_kernel(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"])
if params["saturation_factor"] is not None:
inpt = F.adjust_saturation(inpt, saturation_factor=params["saturation_factor"])
inpt = self._call_kernel(F.adjust_saturation, inpt, saturation_factor=params["saturation_factor"])
if params["hue_factor"] is not None:
inpt = F.adjust_hue(inpt, hue_factor=params["hue_factor"])
inpt = self._call_kernel(F.adjust_hue, inpt, hue_factor=params["hue_factor"])
if params["contrast_factor"] is not None and not params["contrast_before"]:
inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"])
inpt = self._call_kernel(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"])
if params["channel_permutation"] is not None:
inpt = F.permute_channels(inpt, permutation=params["channel_permutation"])
inpt = self._call_kernel(F.permute_channels, inpt, permutation=params["channel_permutation"])
return inpt
......@@ -293,7 +264,7 @@ class RandomEqualize(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomEqualize
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.equalize(inpt)
return self._call_kernel(F.equalize, inpt)
class RandomInvert(_RandomApplyTransform):
......@@ -312,7 +283,7 @@ class RandomInvert(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomInvert
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.invert(inpt)
return self._call_kernel(F.invert, inpt)
class RandomPosterize(_RandomApplyTransform):
......@@ -337,7 +308,7 @@ class RandomPosterize(_RandomApplyTransform):
self.bits = bits
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.posterize(inpt, bits=self.bits)
return self._call_kernel(F.posterize, inpt, bits=self.bits)
class RandomSolarize(_RandomApplyTransform):
......@@ -362,7 +333,7 @@ class RandomSolarize(_RandomApplyTransform):
self.threshold = threshold
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.solarize(inpt, threshold=self.threshold)
return self._call_kernel(F.solarize, inpt, threshold=self.threshold)
class RandomAutocontrast(_RandomApplyTransform):
......@@ -381,7 +352,7 @@ class RandomAutocontrast(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomAutocontrast
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.autocontrast(inpt)
return self._call_kernel(F.autocontrast, inpt)
class RandomAdjustSharpness(_RandomApplyTransform):
......@@ -406,4 +377,4 @@ class RandomAdjustSharpness(_RandomApplyTransform):
self.sharpness_factor = sharpness_factor
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.adjust_sharpness(inpt, sharpness_factor=self.sharpness_factor)
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=self.sharpness_factor)
import math
import numbers
import warnings
from typing import Any, cast, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union
from typing import Any, Callable, cast, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union
import PIL.Image
import torch
......@@ -44,7 +44,7 @@ class RandomHorizontalFlip(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomHorizontalFlip
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.horizontal_flip(inpt)
return self._call_kernel(F.horizontal_flip, inpt)
class RandomVerticalFlip(_RandomApplyTransform):
......@@ -64,7 +64,7 @@ class RandomVerticalFlip(_RandomApplyTransform):
_v1_transform_cls = _transforms.RandomVerticalFlip
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.vertical_flip(inpt)
return self._call_kernel(F.vertical_flip, inpt)
class Resize(Transform):
......@@ -152,7 +152,8 @@ class Resize(Transform):
self.antialias = antialias
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(
return self._call_kernel(
F.resize,
inpt,
self.size,
interpolation=self.interpolation,
......@@ -186,7 +187,7 @@ class CenterCrop(Transform):
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.center_crop(inpt, output_size=self.size)
return self._call_kernel(F.center_crop, inpt, output_size=self.size)
class RandomResizedCrop(Transform):
......@@ -307,8 +308,8 @@ class RandomResizedCrop(Transform):
return dict(top=i, left=j, height=h, width=w)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resized_crop(
inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias
return self._call_kernel(
F.resized_crop, inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias
)
......@@ -357,8 +358,16 @@ class FiveCrop(Transform):
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
warnings.warn(
f"{type(self).__name__}() is currently passing through inputs of type "
f"datapoints.{type(inpt).__name__}. This will likely change in the future."
)
return super()._call_kernel(dispatcher, inpt, *args, **kwargs)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.five_crop(inpt, self.size)
return self._call_kernel(F.five_crop, inpt, self.size)
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):
......@@ -396,12 +405,20 @@ class TenCrop(Transform):
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
self.vertical_flip = vertical_flip
def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
warnings.warn(
f"{type(self).__name__}() is currently passing through inputs of type "
f"datapoints.{type(inpt).__name__}. This will likely change in the future."
)
return super()._call_kernel(dispatcher, inpt, *args, **kwargs)
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):
raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
return self._call_kernel(F.ten_crop, inpt, self.size, vertical_flip=self.vertical_flip)
class Pad(Transform):
......@@ -475,7 +492,7 @@ class Pad(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = _get_fill(self._fill, type(inpt))
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]
return self._call_kernel(F.pad, inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]
class RandomZoomOut(_RandomApplyTransform):
......@@ -545,7 +562,7 @@ class RandomZoomOut(_RandomApplyTransform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = _get_fill(self._fill, type(inpt))
return F.pad(inpt, **params, fill=fill)
return self._call_kernel(F.pad, inpt, **params, fill=fill)
class RandomRotation(Transform):
......@@ -611,7 +628,8 @@ class RandomRotation(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = _get_fill(self._fill, type(inpt))
return F.rotate(
return self._call_kernel(
F.rotate,
inpt,
**params,
interpolation=self.interpolation,
......@@ -733,7 +751,8 @@ class RandomAffine(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = _get_fill(self._fill, type(inpt))
return F.affine(
return self._call_kernel(
F.affine,
inpt,
**params,
interpolation=self.interpolation,
......@@ -889,10 +908,12 @@ class RandomCrop(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["needs_pad"]:
fill = _get_fill(self._fill, type(inpt))
inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
inpt = self._call_kernel(F.pad, inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
if params["needs_crop"]:
inpt = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
inpt = self._call_kernel(
F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
)
return inpt
......@@ -973,7 +994,8 @@ class RandomPerspective(_RandomApplyTransform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = _get_fill(self._fill, type(inpt))
return F.perspective(
return self._call_kernel(
F.perspective,
inpt,
None,
None,
......@@ -1050,7 +1072,7 @@ class ElasticTransform(Transform):
# if kernel size is even we have to make it odd
if kx % 2 == 0:
kx += 1
dx = F.gaussian_blur(dx, [kx, kx], list(self.sigma))
dx = self._call_kernel(F.gaussian_blur, dx, [kx, kx], list(self.sigma))
dx = dx * self.alpha[0] / size[0]
dy = torch.rand([1, 1] + size) * 2 - 1
......@@ -1059,14 +1081,15 @@ class ElasticTransform(Transform):
# if kernel size is even we have to make it odd
if ky % 2 == 0:
ky += 1
dy = F.gaussian_blur(dy, [ky, ky], list(self.sigma))
dy = self._call_kernel(F.gaussian_blur, dy, [ky, ky], list(self.sigma))
dy = dy * self.alpha[1] / size[1]
displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2
return dict(displacement=displacement)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = _get_fill(self._fill, type(inpt))
return F.elastic(
return self._call_kernel(
F.elastic,
inpt,
**params,
fill=fill,
......@@ -1164,7 +1187,9 @@ class RandomIoUCrop(Transform):
# check for any valid boxes with centers within the crop area
xyxy_bboxes = F.convert_format_bounding_boxes(
bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY
bboxes.as_subclass(torch.Tensor),
bboxes.format,
datapoints.BoundingBoxFormat.XYXY,
)
cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
......@@ -1188,7 +1213,9 @@ class RandomIoUCrop(Transform):
if len(params) < 1:
return inpt
output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
output = self._call_kernel(
F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
)
if isinstance(output, datapoints.BoundingBoxes):
# We "mark" the invalid boxes as degenreate, and they can be
......@@ -1262,7 +1289,9 @@ class ScaleJitter(Transform):
return dict(size=(new_height, new_width))
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias)
return self._call_kernel(
F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
)
class RandomShortestSize(Transform):
......@@ -1330,7 +1359,9 @@ class RandomShortestSize(Transform):
return dict(size=(new_height, new_width))
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias)
return self._call_kernel(
F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
)
class RandomResize(Transform):
......@@ -1400,4 +1431,6 @@ class RandomResize(Transform):
return dict(size=[size])
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.resize(inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias)
return self._call_kernel(
F.resize, inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias
)
......@@ -106,7 +106,7 @@ class LinearTransformation(Transform):
def _check_inputs(self, sample: Any) -> Any:
if has_any(sample, PIL.Image.Image):
raise TypeError("LinearTransformation does not work on PIL Images")
raise TypeError(f"{type(self).__name__}() does not support PIL images.")
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
shape = inpt.shape
......@@ -157,7 +157,6 @@ class Normalize(Transform):
"""
_v1_transform_cls = _transforms.Normalize
_transformed_types = (datapoints.Image, is_simple_tensor, datapoints.Video)
def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
super().__init__()
......@@ -170,7 +169,7 @@ class Normalize(Transform):
raise TypeError(f"{type(self).__name__}() does not support PIL images.")
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
return self._call_kernel(F.normalize, inpt, mean=self.mean, std=self.std, inplace=self.inplace)
class GaussianBlur(Transform):
......@@ -217,7 +216,7 @@ class GaussianBlur(Transform):
return dict(sigma=[sigma, sigma])
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.gaussian_blur(inpt, self.kernel_size, **params)
return self._call_kernel(F.gaussian_blur, inpt, self.kernel_size, **params)
class ToDtype(Transform):
......@@ -290,7 +289,7 @@ class ToDtype(Transform):
)
return inpt
return F.to_dtype(inpt, dtype=dtype, scale=self.scale)
return self._call_kernel(F.to_dtype, inpt, dtype=dtype, scale=self.scale)
class ConvertImageDtype(Transform):
......@@ -320,14 +319,12 @@ class ConvertImageDtype(Transform):
_v1_transform_cls = _transforms.ConvertImageDtype
_transformed_types = (is_simple_tensor, datapoints.Image)
def __init__(self, dtype: torch.dtype = torch.float32) -> None:
super().__init__()
self.dtype = dtype
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.to_dtype(inpt, dtype=self.dtype, scale=True)
return self._call_kernel(F.to_dtype, inpt, dtype=self.dtype, scale=True)
class SanitizeBoundingBoxes(Transform):
......
......@@ -25,4 +25,4 @@ class UniformTemporalSubsample(Transform):
self.num_samples = num_samples
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.uniform_temporal_subsample(inpt, self.num_samples)
return self._call_kernel(F.uniform_temporal_subsample, inpt, self.num_samples)
......@@ -11,6 +11,8 @@ from torchvision import datapoints
from torchvision.transforms.v2.utils import check_type, has_any, is_simple_tensor
from torchvision.utils import _log_api_usage_once
from .functional._utils import _get_kernel
class Transform(nn.Module):
......@@ -28,6 +30,10 @@ class Transform(nn.Module):
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict()
def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
kernel = _get_kernel(dispatcher, type(inpt), allow_passthrough=True)
return kernel(inpt, *args, **kwargs)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
raise NotImplementedError
......
......@@ -5,10 +5,9 @@ from torchvision import datapoints
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once
from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
from ._utils import _get_kernel, _register_kernel_internal
@_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True)
def erase(
inpt: torch.Tensor,
i: int,
......
......@@ -10,12 +10,10 @@ from torchvision.transforms._functional_tensor import _max_value
from torchvision.utils import _log_api_usage_once
from ._misc import _num_value_bits, to_dtype_image_tensor
from ._type_conversion import pil_to_tensor, to_image_pil
from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
from ._utils import _get_kernel, _register_kernel_internal
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, datapoints.Video)
def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
if torch.jit.is_scripting():
return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels)
......@@ -70,8 +68,8 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Te
return output if fp else output.to(image1.dtype)
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Tensor:
if torch.jit.is_scripting():
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
......@@ -107,7 +105,6 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to
return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor)
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Tensor:
if torch.jit.is_scripting():
return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
......@@ -146,7 +143,6 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to
return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor)
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor:
if torch.jit.is_scripting():
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
......@@ -185,7 +181,6 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.
return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor)
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
if torch.jit.is_scripting():
return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
......@@ -258,7 +253,6 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc
return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor)
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def adjust_hue(inpt: torch.Tensor, hue_factor: float) -> torch.Tensor:
if torch.jit.is_scripting():
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
......@@ -370,7 +364,6 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
return adjust_hue_image_tensor(video, hue_factor=hue_factor)
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
if torch.jit.is_scripting():
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
......@@ -410,7 +403,6 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to
return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain)
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:
if torch.jit.is_scripting():
return posterize_image_tensor(inpt, bits=bits)
......@@ -444,7 +436,6 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
return posterize_image_tensor(video, bits=bits)
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor:
if torch.jit.is_scripting():
return solarize_image_tensor(inpt, threshold=threshold)
......@@ -472,7 +463,6 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
return solarize_image_tensor(video, threshold=threshold)
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def autocontrast(inpt: torch.Tensor) -> torch.Tensor:
if torch.jit.is_scripting():
return autocontrast_image_tensor(inpt)
......@@ -522,7 +512,6 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
return autocontrast_image_tensor(video)
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def equalize(inpt: torch.Tensor) -> torch.Tensor:
if torch.jit.is_scripting():
return equalize_image_tensor(inpt)
......@@ -612,7 +601,6 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor:
return equalize_image_tensor(video)
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def invert(inpt: torch.Tensor) -> torch.Tensor:
if torch.jit.is_scripting():
return invert_image_tensor(inpt)
......@@ -643,7 +631,6 @@ def invert_video(video: torch.Tensor) -> torch.Tensor:
return invert_image_tensor(video)
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor:
"""Permute the channels of the input according to the given permutation.
......
......@@ -25,13 +25,7 @@ from torchvision.utils import _log_api_usage_once
from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil
from ._utils import (
_FillTypeJIT,
_get_kernel,
_register_explicit_noop,
_register_five_ten_crop_kernel,
_register_kernel_internal,
)
from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal
def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
......@@ -2203,7 +2197,6 @@ def resized_crop_video(
)
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True)
def five_crop(
inpt: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
......@@ -2230,8 +2223,8 @@ def _parse_five_crop_size(size: List[int]) -> List[int]:
return size
@_register_five_ten_crop_kernel(five_crop, torch.Tensor)
@_register_five_ten_crop_kernel(five_crop, datapoints.Image)
@_register_five_ten_crop_kernel_internal(five_crop, torch.Tensor)
@_register_five_ten_crop_kernel_internal(five_crop, datapoints.Image)
def five_crop_image_tensor(
image: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
......@@ -2250,7 +2243,7 @@ def five_crop_image_tensor(
return tl, tr, bl, br, center
@_register_five_ten_crop_kernel(five_crop, PIL.Image.Image)
@_register_five_ten_crop_kernel_internal(five_crop, PIL.Image.Image)
def five_crop_image_pil(
image: PIL.Image.Image, size: List[int]
) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
......@@ -2269,14 +2262,13 @@ def five_crop_image_pil(
return tl, tr, bl, br, center
@_register_five_ten_crop_kernel(five_crop, datapoints.Video)
@_register_five_ten_crop_kernel_internal(five_crop, datapoints.Video)
def five_crop_video(
video: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
return five_crop_image_tensor(video, size)
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True)
def ten_crop(
inpt: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
......@@ -2300,8 +2292,8 @@ def ten_crop(
return kernel(inpt, size=size, vertical_flip=vertical_flip)
@_register_five_ten_crop_kernel(ten_crop, torch.Tensor)
@_register_five_ten_crop_kernel(ten_crop, datapoints.Image)
@_register_five_ten_crop_kernel_internal(ten_crop, torch.Tensor)
@_register_five_ten_crop_kernel_internal(ten_crop, datapoints.Image)
def ten_crop_image_tensor(
image: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
......@@ -2328,7 +2320,7 @@ def ten_crop_image_tensor(
return non_flipped + flipped
@_register_five_ten_crop_kernel(ten_crop, PIL.Image.Image)
@_register_five_ten_crop_kernel_internal(ten_crop, PIL.Image.Image)
def ten_crop_image_pil(
image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
) -> Tuple[
......@@ -2355,7 +2347,7 @@ def ten_crop_image_pil(
return non_flipped + flipped
@_register_five_ten_crop_kernel(ten_crop, datapoints.Video)
@_register_five_ten_crop_kernel_internal(ten_crop, datapoints.Video)
def ten_crop_video(
video: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
......
......@@ -8,10 +8,9 @@ from torchvision.transforms import _functional_pil as _FP
from torchvision.utils import _log_api_usage_once
from ._utils import _get_kernel, _register_kernel_internal, _register_unsupported_type, is_simple_tensor
from ._utils import _get_kernel, _register_kernel_internal, is_simple_tensor
@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
def get_dimensions(inpt: torch.Tensor) -> List[int]:
if torch.jit.is_scripting():
return get_dimensions_image_tensor(inpt)
......@@ -44,7 +43,6 @@ def get_dimensions_video(video: torch.Tensor) -> List[int]:
return get_dimensions_image_tensor(video)
@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask)
def get_num_channels(inpt: torch.Tensor) -> int:
if torch.jit.is_scripting():
return get_num_channels_image_tensor(inpt)
......@@ -123,7 +121,6 @@ def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int]
return list(bounding_box.canvas_size)
@_register_unsupported_type(PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask)
def get_num_frames(inpt: torch.Tensor) -> int:
if torch.jit.is_scripting():
return get_num_frames_video(inpt)
......
......@@ -11,11 +11,9 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once
from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, _register_unsupported_type
from ._utils import _get_kernel, _register_kernel_internal
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
@_register_unsupported_type(PIL.Image.Image)
def normalize(
inpt: torch.Tensor,
mean: List[float],
......@@ -73,7 +71,6 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], in
return normalize_image_tensor(video, mean, std, inplace=inplace)
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def gaussian_blur(inpt: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> torch.Tensor:
if torch.jit.is_scripting():
return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
......@@ -182,7 +179,6 @@ def gaussian_blur_video(
return gaussian_blur_image_tensor(video, kernel_size, sigma)
@_register_unsupported_type(PIL.Image.Image)
def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
if torch.jit.is_scripting():
return to_dtype_image_tensor(inpt, dtype=dtype, scale=scale)
......
import PIL.Image
import torch
from torchvision import datapoints
from torchvision.utils import _log_api_usage_once
from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal
from ._utils import _get_kernel, _register_kernel_internal
@_register_explicit_noop(
PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True
)
def uniform_temporal_subsample(inpt: torch.Tensor, num_samples: int) -> torch.Tensor:
if torch.jit.is_scripting():
return uniform_temporal_subsample_video(inpt, num_samples=num_samples)
......
import functools
import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
import torch
......@@ -53,6 +52,11 @@ def _name_to_dispatcher(name):
) from None
_BUILTIN_DATAPOINT_TYPES = {
obj for obj in datapoints.__dict__.values() if isinstance(obj, type) and issubclass(obj, datapoints.Datapoint)
}
def register_kernel(dispatcher, datapoint_cls):
"""Decorate a kernel to register it for a dispatcher and a (custom) datapoint type.
......@@ -70,20 +74,19 @@ def register_kernel(dispatcher, datapoint_cls):
f"but got {dispatcher}."
)
if not (
isinstance(datapoint_cls, type)
and issubclass(datapoint_cls, datapoints.Datapoint)
and datapoint_cls is not datapoints.Datapoint
):
if not (isinstance(datapoint_cls, type) and issubclass(datapoint_cls, datapoints.Datapoint)):
raise ValueError(
f"Kernels can only be registered for subclasses of torchvision.datapoints.Datapoint, "
f"but got {datapoint_cls}."
)
if datapoint_cls in _BUILTIN_DATAPOINT_TYPES:
raise ValueError(f"Kernels cannot be registered for the builtin datapoint classes, but got {datapoint_cls}")
return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)
def _get_kernel(dispatcher, input_type):
def _get_kernel(dispatcher, input_type, *, allow_passthrough=False):
registry = _KERNEL_REGISTRY.get(dispatcher)
if not registry:
raise ValueError(f"No kernel registered for dispatcher {dispatcher.__name__}.")
......@@ -104,78 +107,18 @@ def _get_kernel(dispatcher, input_type):
elif cls in registry:
return registry[cls]
# Note that in the future we are not going to return a noop here, but rather raise the error below
return _noop
if allow_passthrough:
return lambda inpt, *args, **kwargs: inpt
raise TypeError(
f"Dispatcher {dispatcher} supports inputs of type torch.Tensor, PIL.Image.Image, "
f"and subclasses of torchvision.datapoints.Datapoint, "
f"Dispatcher F.{dispatcher.__name__} supports inputs of type {registry.keys()}, "
f"but got {input_type} instead."
)
# Everything below this block is stuff that we need right now, since it looks like we need to release in an intermediate
# stage. See https://github.com/pytorch/vision/pull/7747#issuecomment-1661698450 for details.
# In the future, the default behavior will be to error on unsupported types in dispatchers. The noop behavior that we
# need for transforms will be handled by _get_kernel rather than actually registering no-ops on the dispatcher.
# Finally, the use case of preventing users from registering kernels for our builtin types will be handled inside
# register_kernel.
def _register_explicit_noop(*datapoints_classes, warn_passthrough=False):
"""
Although this looks redundant with the no-op behavior of _get_kernel, this explicit registration prevents users
from registering kernels for builtin datapoints on builtin dispatchers that rely on the no-op behavior.
For example, without explicit no-op registration the following would be valid user code:
.. code::
from torchvision.transforms.v2 import functional as F
@F.register_kernel(F.adjust_brightness, datapoints.BoundingBox)
def lol(...):
...
"""
def decorator(dispatcher):
for cls in datapoints_classes:
msg = (
f"F.{dispatcher.__name__} is currently passing through inputs of type datapoints.{cls.__name__}. "
f"This will likely change in the future."
)
_register_kernel_internal(dispatcher, cls, datapoint_wrapper=False)(
functools.partial(_noop, __msg__=msg if warn_passthrough else None)
)
return dispatcher
return decorator
def _noop(inpt, *args, __msg__=None, **kwargs):
if __msg__:
warnings.warn(__msg__, UserWarning, stacklevel=2)
return inpt
# TODO: we only need this, since our default behavior in case no kernel is found is passthrough. When we change that
# to error later, this decorator can be removed, since the error will be raised by _get_kernel
def _register_unsupported_type(*input_types):
def kernel(inpt, *args, __dispatcher_name__, **kwargs):
raise TypeError(f"F.{__dispatcher_name__} does not support inputs of type {type(inpt)}.")
def decorator(dispatcher):
for input_type in input_types:
_register_kernel_internal(dispatcher, input_type, datapoint_wrapper=False)(
functools.partial(kernel, __dispatcher_name__=dispatcher.__name__)
)
return dispatcher
return decorator
# This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop
# We could get rid of this by letting _register_kernel_internal take arbitrary dispatchers rather than wrap_kernel: bool
def _register_five_ten_crop_kernel(dispatcher, input_type):
def _register_five_ten_crop_kernel_internal(dispatcher, input_type):
registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
if input_type in registry:
raise TypeError(f"Dispatcher '{dispatcher}' already has a kernel registered for type '{input_type}'.")
......
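
Taken together with the register_kernel changes above, registering a kernel for a custom datapoint would look roughly like this; CustomDatapoint and resize_custom are hypothetical names, and builtin datapoint classes are now rejected up front:

# Hypothetical registration example, assuming the torchvision.transforms.v2 API of this branch.
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

class CustomDatapoint(datapoints.Datapoint):
    pass

@F.register_kernel(F.resize, CustomDatapoint)
def resize_custom(inpt, size, **kwargs):
    # Called whenever F.resize (or a transform routed through _call_kernel) receives a CustomDatapoint.
    return inpt

# Registering for a builtin type now fails immediately:
# F.register_kernel(F.resize, datapoints.Image)(...)  # ValueError: cannot be registered for the builtin datapoint classes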