Unverified commit a893f313 authored by Philip Meier, committed by GitHub

refactor Datapoint dispatch mechanism (#7747)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent 16d62e30
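At a high level, this change replaces the per-class dispatch methods on the Datapoint subclasses (e.g. Image.resize, BoundingBoxes.horizontal_flip, and the no-op fallbacks on Datapoint itself) with a registry in torchvision.transforms.v2.functional._utils that maps each functional dispatcher to the kernel registered for the input's datapoint type. The snippet below is a minimal sketch of that pattern, not the actual torchvision implementation: the names _KERNEL_REGISTRY and _register_kernel_internal come from this diff, while the exact wrapper behavior shown here is an assumption.

import functools
from typing import Any, Callable, Dict

import torch

# dispatcher -> {datapoint class -> wrapped kernel}
_KERNEL_REGISTRY: Dict[Callable, Dict[type, Callable]] = {}


def _register_kernel_internal(dispatcher: Callable, datapoint_cls: type) -> Callable:
    # Decorator registering `kernel` as the implementation of `dispatcher` for
    # inputs of type `datapoint_cls`.
    def decorator(kernel: Callable) -> Callable:
        @functools.wraps(kernel)  # exposes __wrapped__, which the new tests inspect
        def wrapper(inpt: Any, *args: Any, **kwargs: Any) -> Any:
            # Unwrap the datapoint to a plain tensor, run the kernel, and re-wrap
            # the result (wrap_like is the classmethod the datapoint classes in
            # this diff already provide).
            output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
            return type(inpt).wrap_like(inpt, output)

        _KERNEL_REGISTRY.setdefault(dispatcher, {})[datapoint_cls] = wrapper
        return kernel

    return decorator


def horizontal_flip(inpt: Any, *args: Any, **kwargs: Any) -> Any:
    # Dispatcher: look up the kernel registered for the input's type and call it.
    kernel = _KERNEL_REGISTRY.get(horizontal_flip, {}).get(type(inpt))
    if kernel is None:
        raise TypeError(f"horizontal_flip() got an unsupported input of type {type(inpt)}")
    return kernel(inpt, *args, **kwargs)

Because the kernel lookup now lives next to the functional API instead of on the datapoint classes, the lazy Datapoint._F import hack removed further down is no longer needed, and the tests can spy on dispatch by patching the registry entry directly (see the mock.patch.dict(_KERNEL_REGISTRY[...]) calls below).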
@@ -829,6 +829,10 @@ def make_video(size=DEFAULT_SIZE, *, num_frames=3, batch_dims=(), **kwargs):
return datapoints.Video(make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs))
def make_video_tensor(*args, **kwargs):
return make_video(*args, **kwargs).as_subclass(torch.Tensor)
def make_video_loader(
size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
*,
...
@@ -567,7 +567,7 @@ class DatasetTestCase(unittest.TestCase):
@test_all_configs
def test_transforms_v2_wrapper(self, config):
from torchvision.datapoints._datapoint import Datapoint
from torchvision import datapoints
from torchvision.datasets import wrap_dataset_for_transforms_v2
try:
@@ -588,7 +588,9 @@ class DatasetTestCase(unittest.TestCase):
assert len(wrapped_dataset) == info["num_examples"]
wrapped_sample = wrapped_dataset[0]
assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
assert tree_any(
lambda item: isinstance(item, (datapoints.Datapoint, PIL.Image.Image)), wrapped_sample
)
except TypeError as error:
msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
if str(error).startswith(msg):
...
@@ -1344,12 +1344,12 @@ def test_antialias_warning():
transforms.RandomResize(10, 20)(tensor_img)
with pytest.warns(UserWarning, match=match):
datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20))
F.resized_crop(datapoints.Image(tensor_img), 0, 0, 10, 10, (20, 20))
with pytest.warns(UserWarning, match=match):
datapoints.Video(tensor_video).resize((20, 20))
F.resize(datapoints.Video(tensor_video), (20, 20))
with pytest.warns(UserWarning, match=match):
datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20))
F.resized_crop(datapoints.Video(tensor_video), 0, 0, 10, 10, (20, 20))
with warnings.catch_warnings():
warnings.simplefilter("error")
@@ -1363,8 +1363,8 @@ def test_antialias_warning():
transforms.RandomShortestSize((20, 20), antialias=True)(tensor_img)
transforms.RandomResize(10, 20, antialias=True)(tensor_img)
datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20), antialias=True)
F.resized_crop(datapoints.Image(tensor_img), 0, 0, 10, 10, (20, 20), antialias=True)
datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20), antialias=True)
F.resized_crop(datapoints.Video(tensor_video), 0, 0, 10, 10, (20, 20), antialias=True)
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
...
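For orientation, the test updates above show the user-facing side of the refactor: transforms are no longer invoked as methods on datapoint instances but through the torchvision.transforms.v2.functional namespace, which dispatches on the input type. An illustrative example, assuming a torchvision build that ships the v2 API and the datapoints namespace:

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

img = datapoints.Image(torch.randint(0, 256, (3, 10, 10), dtype=torch.uint8))

# Previously the tests called methods on the datapoint itself:
#   img.resized_crop(0, 0, 10, 10, (20, 20), antialias=True)
# After this refactor, the call goes through the functional namespace instead:
out = F.resized_crop(img, 0, 0, 10, 10, (20, 20), antialias=True)
assert isinstance(out, datapoints.Image)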
@@ -2,13 +2,11 @@ import inspect
import math
import os
import re
from unittest import mock
from typing import get_type_hints
import numpy as np
import PIL.Image
import pytest
import torch
from common_utils import (
@@ -27,6 +25,7 @@ from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY
from torchvision.transforms.v2.utils import is_simple_tensor
from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
from transforms_v2_kernel_infos import KERNEL_INFOS
@@ -424,12 +423,18 @@ class TestDispatchers:
def test_dispatch_datapoint(self, info, args_kwargs, spy_on):
(datapoint, *other_args), kwargs = args_kwargs.load()
method_name = info.id
method = getattr(datapoint, method_name)
datapoint_type = type(datapoint)
spy = spy_on(method, module=datapoint_type.__module__, name=f"{datapoint_type.__name__}.{method_name}")
info.dispatcher(datapoint, *other_args, **kwargs)
input_type = type(datapoint)
wrapped_kernel = _KERNEL_REGISTRY[info.dispatcher][input_type]
# In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
# proper kernel was wrapped
if hasattr(wrapped_kernel, "__wrapped__"):
assert wrapped_kernel.__wrapped__ is info.kernels[input_type]
spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__)
with mock.patch.dict(_KERNEL_REGISTRY[info.dispatcher], values={input_type: spy}):
info.dispatcher(datapoint, *other_args, **kwargs)
spy.assert_called_once()
@@ -462,9 +467,12 @@ class TestDispatchers:
kernel_params = list(kernel_signature.parameters.values())[1:]
# We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
# explicit passed to the kernel.
datapoint_type_metadata = datapoint_type.__annotations__.keys()
kernel_params = [param for param in kernel_params if param.name not in datapoint_type_metadata]
# explicitly passed to the kernel.
input_type = {v: k for k, v in dispatcher_info.kernels.items()}.get(kernel_info.kernel)
explicit_metadata = {
datapoints.BoundingBoxes: {"format", "canvas_size"},
}
kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]
dispatcher_params = iter(dispatcher_params)
for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
@@ -481,28 +489,6 @@ class TestDispatchers:
assert dispatcher_param == kernel_param
@pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
def test_dispatcher_datapoint_signatures_consistency(self, info):
try:
datapoint_method = getattr(datapoints._datapoint.Datapoint, info.id)
except AttributeError:
pytest.skip("Dispatcher doesn't support arbitrary datapoint dispatch.")
dispatcher_signature = inspect.signature(info.dispatcher)
dispatcher_params = list(dispatcher_signature.parameters.values())[1:]
datapoint_signature = inspect.signature(datapoint_method)
datapoint_params = list(datapoint_signature.parameters.values())[1:]
# Because we use `from __future__ import annotations` inside the module where `datapoints._datapoint` is
# defined, the annotations are stored as strings. This makes them concrete again, so they can be compared to the
# natively concrete dispatcher annotations.
datapoint_annotations = get_type_hints(datapoint_method)
for param in datapoint_params:
param._annotation = datapoint_annotations[param.name]
assert dispatcher_params == datapoint_params
@pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
def test_unkown_type(self, info):
unkown_input = object()
...
@@ -3,7 +3,6 @@ import decimal
import inspect
import math
import re
from typing import get_type_hints
from unittest import mock
import numpy as np
@@ -26,6 +25,7 @@ from common_utils import (
make_image_tensor,
make_segmentation_mask,
make_video,
make_video_tensor,
needs_cuda,
set_rng_seed,
)
@@ -39,6 +39,7 @@ from torchvision import datapoints
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.functional import pil_modes_mapping
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY
@pytest.fixture(autouse=True)
@@ -176,16 +177,19 @@ def _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs):
"""Checks if the dispatcher correctly dispatches the input to the corresponding kernel and that the input type is
preserved in doing so. For bounding boxes also checks that the format is preserved.
"""
if isinstance(input, datapoints._datapoint.Datapoint):
# Due to our complex dispatch architecture for datapoints, we cannot spy on the kernel directly,
# but rather have to patch the `Datapoint.__F` attribute to contain the spied on kernel.
spy = mock.MagicMock(wraps=kernel, name=kernel.__name__)
with mock.patch.object(F, kernel.__name__, spy):
# Due to Python's name mangling, the `Datapoint.__F` attribute is only accessible from inside the class.
# Since that is not the case here, we need to prefix f"_{cls.__name__}"
# See https://docs.python.org/3/tutorial/classes.html#private-variables for details
with mock.patch.object(datapoints._datapoint.Datapoint, "_Datapoint__F", new=F):
output = dispatcher(input, *args, **kwargs)
input_type = type(input)
if isinstance(input, datapoints.Datapoint):
wrapped_kernel = _KERNEL_REGISTRY[dispatcher][input_type]
# In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
# proper kernel was wrapped
if hasattr(wrapped_kernel, "__wrapped__"):
assert wrapped_kernel.__wrapped__ is kernel
spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__)
with mock.patch.dict(_KERNEL_REGISTRY[dispatcher], values={input_type: spy}):
output = dispatcher(input, *args, **kwargs)
spy.assert_called_once()
else:
@@ -194,7 +198,7 @@ def _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs):
spy.assert_called_once()
assert isinstance(output, type(input))
assert isinstance(output, input_type)
if isinstance(input, datapoints.BoundingBoxes):
assert output.format == input.format
@@ -209,15 +213,13 @@ def check_dispatcher(
check_dispatch=True,
**kwargs,
):
unknown_input = object()
with mock.patch("torch._C._log_api_usage_once", wraps=torch._C._log_api_usage_once) as spy:
dispatcher(input, *args, **kwargs)
with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
dispatcher(unknown_input, *args, **kwargs)
spy.assert_any_call(f"{dispatcher.__module__}.{dispatcher.__name__}")
unknown_input = object()
with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
dispatcher(unknown_input, *args, **kwargs)
if check_scripted_smoke:
_check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs)
@@ -225,18 +227,18 @@ def check_dispatcher(
_check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs)
def _check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
def check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
"""Checks if the signature of the dispatcher matches the kernel signature."""
dispatcher_signature = inspect.signature(dispatcher)
dispatcher_params = list(dispatcher_signature.parameters.values())[1:]
kernel_signature = inspect.signature(kernel)
kernel_params = list(kernel_signature.parameters.values())[1:]
dispatcher_params = list(inspect.signature(dispatcher).parameters.values())[1:]
kernel_params = list(inspect.signature(kernel).parameters.values())[1:]
if issubclass(input_type, datapoints._datapoint.Datapoint):
if issubclass(input_type, datapoints.Datapoint):
# We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
# explicitly passed to the kernel.
kernel_params = [param for param in kernel_params if param.name not in input_type.__annotations__.keys()]
explicit_metadata = {
datapoints.BoundingBoxes: {"format", "canvas_size"},
}
kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]
dispatcher_params = iter(dispatcher_params)
for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
@@ -259,30 +261,6 @@ def _check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
assert dispatcher_param == kernel_param
def _check_dispatcher_datapoint_signature_match(dispatcher):
"""Checks if the signature of the dispatcher matches the corresponding method signature on the Datapoint class."""
dispatcher_signature = inspect.signature(dispatcher)
dispatcher_params = list(dispatcher_signature.parameters.values())[1:]
datapoint_method = getattr(datapoints._datapoint.Datapoint, dispatcher.__name__)
datapoint_signature = inspect.signature(datapoint_method)
datapoint_params = list(datapoint_signature.parameters.values())[1:]
# Some annotations in the `datapoints._datapoint` module
# are stored as strings. The block below makes them concrete again (non-strings), so they can be compared to the
# natively concrete dispatcher annotations.
datapoint_annotations = get_type_hints(datapoint_method)
for param in datapoint_params:
param._annotation = datapoint_annotations[param.name]
assert dispatcher_params == datapoint_params
def check_dispatcher_signatures_match(dispatcher, *, kernel, input_type):
_check_dispatcher_kernel_signature_match(dispatcher, kernel=kernel, input_type=input_type)
_check_dispatcher_datapoint_signature_match(dispatcher)
def _check_transform_v1_compatibility(transform, input):
"""If the transform defines the ``_v1_transform_cls`` attribute, checks if the transform has a public, static
``get_params`` method, is scriptable, and the scripted version can be called without error."""
@@ -433,6 +411,33 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
return torch.stack([transform(b) for b in bounding_boxes.reshape(-1, 4).unbind()]).reshape(bounding_boxes.shape)
@pytest.mark.parametrize(
("dispatcher", "registered_datapoint_clss"),
[(dispatcher, set(registry.keys())) for dispatcher, registry in _KERNEL_REGISTRY.items()],
)
def test_exhaustive_kernel_registration(dispatcher, registered_datapoint_clss):
missing = {
datapoints.Image,
datapoints.BoundingBoxes,
datapoints.Mask,
datapoints.Video,
} - registered_datapoint_clss
if missing:
names = sorted(f"datapoints.{cls.__name__}" for cls in missing)
raise AssertionError(
"\n".join(
[
f"The dispatcher '{dispatcher.__name__}' has no kernel registered for",
"",
*[f"- {name}" for name in names],
"",
f"If available, register the kernels with @_register_kernel_internal({dispatcher.__name__}, ...).",
f"If not, register explicit no-ops with @_register_explicit_noop({', '.join(names)})",
]
)
)
class TestResize:
INPUT_SIZE = (17, 11)
OUTPUT_SIZES = [17, [17], (17,), [12, 13], (12, 13)]
@@ -568,7 +573,7 @@ class TestResize:
],
)
def test_dispatcher_signature(self, kernel, input_type):
check_dispatcher_signatures_match(F.resize, kernel=kernel, input_type=input_type)
check_dispatcher_kernel_signature_match(F.resize, kernel=kernel, input_type=input_type)
@pytest.mark.parametrize("size", OUTPUT_SIZES)
@pytest.mark.parametrize("device", cpu_and_cuda())
@@ -766,7 +771,7 @@ class TestResize:
# This identity check is not a requirement. It is here to avoid breaking the behavior by accident. If there
# is a good reason to break this, feel free to downgrade to an equality check.
if isinstance(input, datapoints._datapoint.Datapoint):
if isinstance(input, datapoints.Datapoint):
# We can't test identity directly, since that checks for the identity of the Python object. Since all
# datapoints unwrap before a kernel and wrap again afterwards, the Python object changes. Thus, we check
# that the underlying storage is the same
@@ -850,7 +855,7 @@ class TestHorizontalFlip:
],
)
def test_dispatcher_signature(self, kernel, input_type):
check_dispatcher_signatures_match(F.horizontal_flip, kernel=kernel, input_type=input_type)
check_dispatcher_kernel_signature_match(F.horizontal_flip, kernel=kernel, input_type=input_type)
@pytest.mark.parametrize(
"make_input",
@@ -1033,7 +1038,7 @@ class TestAffine:
],
)
def test_dispatcher_signature(self, kernel, input_type):
check_dispatcher_signatures_match(F.affine, kernel=kernel, input_type=input_type)
check_dispatcher_kernel_signature_match(F.affine, kernel=kernel, input_type=input_type)
@pytest.mark.parametrize(
"make_input",
@@ -1329,7 +1334,7 @@ class TestVerticalFlip:
],
)
def test_dispatcher_signature(self, kernel, input_type):
check_dispatcher_signatures_match(F.vertical_flip, kernel=kernel, input_type=input_type)
check_dispatcher_kernel_signature_match(F.vertical_flip, kernel=kernel, input_type=input_type)
@pytest.mark.parametrize(
"make_input",
@@ -1486,7 +1491,7 @@ class TestRotate:
],
)
def test_dispatcher_signature(self, kernel, input_type):
check_dispatcher_signatures_match(F.rotate, kernel=kernel, input_type=input_type)
check_dispatcher_kernel_signature_match(F.rotate, kernel=kernel, input_type=input_type)
@pytest.mark.parametrize(
"make_input",
@@ -1899,6 +1904,56 @@ class TestToDtype:
assert out["mask"].dtype == mask_dtype
class TestAdjustBrightness:
_CORRECTNESS_BRIGHTNESS_FACTORS = [0.5, 0.0, 1.0, 5.0]
_DEFAULT_BRIGHTNESS_FACTOR = _CORRECTNESS_BRIGHTNESS_FACTORS[0]
@pytest.mark.parametrize(
("kernel", "make_input"),
[
(F.adjust_brightness_image_tensor, make_image),
(F.adjust_brightness_video, make_video),
],
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel(self, kernel, make_input, dtype, device):
check_kernel(kernel, make_input(dtype=dtype, device=device), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)
@pytest.mark.parametrize(
("kernel", "make_input"),
[
(F.adjust_brightness_image_tensor, make_image_tensor),
(F.adjust_brightness_image_pil, make_image_pil),
(F.adjust_brightness_image_tensor, make_image),
(F.adjust_brightness_video, make_video),
],
)
def test_dispatcher(self, kernel, make_input):
check_dispatcher(F.adjust_brightness, kernel, make_input(), brightness_factor=self._DEFAULT_BRIGHTNESS_FACTOR)
@pytest.mark.parametrize(
("kernel", "input_type"),
[
(F.adjust_brightness_image_tensor, torch.Tensor),
(F.adjust_brightness_image_pil, PIL.Image.Image),
(F.adjust_brightness_image_tensor, datapoints.Image),
(F.adjust_brightness_video, datapoints.Video),
],
)
def test_dispatcher_signature(self, kernel, input_type):
check_dispatcher_kernel_signature_match(F.adjust_brightness, kernel=kernel, input_type=input_type)
@pytest.mark.parametrize("brightness_factor", _CORRECTNESS_BRIGHTNESS_FACTORS)
def test_image_correctness(self, brightness_factor):
image = make_image(dtype=torch.uint8, device="cpu")
actual = F.adjust_brightness(image, brightness_factor=brightness_factor)
expected = F.to_image_tensor(F.adjust_brightness(F.to_image_pil(image), brightness_factor=brightness_factor))
torch.testing.assert_close(actual, expected)
class TestCutMixMixUp:
class DummyDataset:
def __init__(self, size, num_classes):
@@ -2036,3 +2091,93 @@ def test_labels_getter_default_heuristic(key, sample_type):
# it takes precedence over other keys which would otherwise be a match
d = {key: "something_else", "labels": labels}
assert transforms._utils._find_labels_default_heuristic(d) is labels
class TestShapeGetters:
@pytest.mark.parametrize(
("kernel", "make_input"),
[
(F.get_dimensions_image_tensor, make_image_tensor),
(F.get_dimensions_image_pil, make_image_pil),
(F.get_dimensions_image_tensor, make_image),
(F.get_dimensions_video, make_video),
],
)
def test_get_dimensions(self, kernel, make_input):
size = (10, 10)
color_space, num_channels = "RGB", 3
input = make_input(size, color_space=color_space)
assert kernel(input) == F.get_dimensions(input) == [num_channels, *size]
@pytest.mark.parametrize(
("kernel", "make_input"),
[
(F.get_num_channels_image_tensor, make_image_tensor),
(F.get_num_channels_image_pil, make_image_pil),
(F.get_num_channels_image_tensor, make_image),
(F.get_num_channels_video, make_video),
],
)
def test_get_num_channels(self, kernel, make_input):
color_space, num_channels = "RGB", 3
input = make_input(color_space=color_space)
assert kernel(input) == F.get_num_channels(input) == num_channels
@pytest.mark.parametrize(
("kernel", "make_input"),
[
(F.get_size_image_tensor, make_image_tensor),
(F.get_size_image_pil, make_image_pil),
(F.get_size_image_tensor, make_image),
(F.get_size_bounding_boxes, make_bounding_box),
(F.get_size_mask, make_detection_mask),
(F.get_size_mask, make_segmentation_mask),
(F.get_size_video, make_video),
],
)
def test_get_size(self, kernel, make_input):
size = (10, 10)
input = make_input(size)
assert kernel(input) == F.get_size(input) == list(size)
@pytest.mark.parametrize(
("kernel", "make_input"),
[
(F.get_num_frames_video, make_video_tensor),
(F.get_num_frames_video, make_video),
],
)
def test_get_num_frames(self, kernel, make_input):
num_frames = 4
input = make_input(num_frames=num_frames)
assert kernel(input) == F.get_num_frames(input) == num_frames
@pytest.mark.parametrize(
("dispatcher", "make_input"),
[
(F.get_dimensions, make_bounding_box),
(F.get_dimensions, make_detection_mask),
(F.get_dimensions, make_segmentation_mask),
(F.get_num_channels, make_bounding_box),
(F.get_num_channels, make_detection_mask),
(F.get_num_channels, make_segmentation_mask),
(F.get_num_frames, make_image_pil),
(F.get_num_frames, make_image),
(F.get_num_frames, make_bounding_box),
(F.get_num_frames, make_detection_mask),
(F.get_num_frames, make_segmentation_mask),
],
)
def test_unsupported_types(self, dispatcher, make_input):
input = make_input()
with pytest.raises(TypeError, match=re.escape(str(type(input)))):
dispatcher(input)
@@ -69,14 +69,15 @@ class DispatcherInfo(InfoBase):
import itertools
for args_kwargs in sample_inputs:
for name in itertools.chain(
datapoint_type.__annotations__.keys(),
# FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a
# per-dispatcher level. However, so far there is no option for that.
(f"old_{name}" for name in datapoint_type.__annotations__.keys()),
):
if name in args_kwargs.kwargs:
del args_kwargs.kwargs[name]
if hasattr(datapoint_type, "__annotations__"):
for name in itertools.chain(
datapoint_type.__annotations__.keys(),
# FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a
# per-dispatcher level. However, so far there is no option for that.
(f"old_{name}" for name in datapoint_type.__annotations__.keys()),
):
if name in args_kwargs.kwargs:
del args_kwargs.kwargs[name]
yield args_kwargs
@@ -289,14 +290,6 @@ DISPATCHER_INFOS = [
skip_dispatch_datapoint,
],
),
DispatcherInfo(
F.adjust_brightness,
kernels={
datapoints.Image: F.adjust_brightness_image_tensor,
datapoints.Video: F.adjust_brightness_video,
},
pil_kernel_info=PILKernelInfo(F.adjust_brightness_image_pil, kernel_name="adjust_brightness_image_pil"),
),
DispatcherInfo(
F.adjust_contrast,
kernels={
...
@@ -1259,46 +1259,6 @@ KERNEL_INFOS.extend(
]
)
_ADJUST_BRIGHTNESS_FACTORS = [0.1, 0.5]
def sample_inputs_adjust_brightness_image_tensor():
for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
yield ArgsKwargs(image_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0])
def reference_inputs_adjust_brightness_image_tensor():
for image_loader, brightness_factor in itertools.product(
make_image_loaders(color_spaces=("GRAY", "RGB"), extra_dims=[()], dtypes=[torch.uint8]),
_ADJUST_BRIGHTNESS_FACTORS,
):
yield ArgsKwargs(image_loader, brightness_factor=brightness_factor)
def sample_inputs_adjust_brightness_video():
for video_loader in make_video_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], num_frames=[3]):
yield ArgsKwargs(video_loader, brightness_factor=_ADJUST_BRIGHTNESS_FACTORS[0])
KERNEL_INFOS.extend(
[
KernelInfo(
F.adjust_brightness_image_tensor,
kernel_name="adjust_brightness_image_tensor",
sample_inputs_fn=sample_inputs_adjust_brightness_image_tensor,
reference_fn=pil_reference_wrapper(F.adjust_brightness_image_pil),
reference_inputs_fn=reference_inputs_adjust_brightness_image_tensor,
float32_vs_uint8=True,
closeness_kwargs=float32_vs_uint8_pixel_difference(),
),
KernelInfo(
F.adjust_brightness_video,
sample_inputs_fn=sample_inputs_adjust_brightness_video,
),
]
)
_ADJUST_CONTRAST_FACTORS = [0.1, 0.5]
...
from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS
from ._bounding_box import BoundingBoxes, BoundingBoxFormat
from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT
from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT, Datapoint
from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image
from ._mask import Mask
from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video
...
from __future__ import annotations
from enum import Enum
from typing import Any, List, Optional, Sequence, Tuple, Union
from typing import Any, Optional, Tuple, Union
import torch
from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms
from ._datapoint import _FillTypeJIT, Datapoint
from ._datapoint import Datapoint
class BoundingBoxFormat(Enum):
@@ -97,141 +96,3 @@ class BoundingBoxes(Datapoint):
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
return self._make_repr(format=self.format, canvas_size=self.canvas_size)
def horizontal_flip(self) -> BoundingBoxes:
output = self._F.horizontal_flip_bounding_boxes(
self.as_subclass(torch.Tensor), format=self.format, canvas_size=self.canvas_size
)
return BoundingBoxes.wrap_like(self, output)
def vertical_flip(self) -> BoundingBoxes:
output = self._F.vertical_flip_bounding_boxes(
self.as_subclass(torch.Tensor), format=self.format, canvas_size=self.canvas_size
)
return BoundingBoxes.wrap_like(self, output)
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> BoundingBoxes:
output, canvas_size = self._F.resize_bounding_boxes(
self.as_subclass(torch.Tensor),
canvas_size=self.canvas_size,
size=size,
max_size=max_size,
)
return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size)
def crop(self, top: int, left: int, height: int, width: int) -> BoundingBoxes:
output, canvas_size = self._F.crop_bounding_boxes(
self.as_subclass(torch.Tensor), self.format, top=top, left=left, height=height, width=width
)
return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size)
def center_crop(self, output_size: List[int]) -> BoundingBoxes:
output, canvas_size = self._F.center_crop_bounding_boxes(
self.as_subclass(torch.Tensor), format=self.format, canvas_size=self.canvas_size, output_size=output_size
)
return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size)
def resized_crop(
self,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> BoundingBoxes:
output, canvas_size = self._F.resized_crop_bounding_boxes(
self.as_subclass(torch.Tensor), self.format, top, left, height, width, size=size
)
return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size)
def pad(
self,
padding: Union[int, Sequence[int]],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> BoundingBoxes:
output, canvas_size = self._F.pad_bounding_boxes(
self.as_subclass(torch.Tensor),
format=self.format,
canvas_size=self.canvas_size,
padding=padding,
padding_mode=padding_mode,
)
return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size)
def rotate(
self,
angle: float,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: _FillTypeJIT = None,
) -> BoundingBoxes:
output, canvas_size = self._F.rotate_bounding_boxes(
self.as_subclass(torch.Tensor),
format=self.format,
canvas_size=self.canvas_size,
angle=angle,
expand=expand,
center=center,
)
return BoundingBoxes.wrap_like(self, output, canvas_size=canvas_size)
def affine(
self,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: _FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> BoundingBoxes:
output = self._F.affine_bounding_boxes(
self.as_subclass(torch.Tensor),
self.format,
self.canvas_size,
angle,
translate=translate,
scale=scale,
shear=shear,
center=center,
)
return BoundingBoxes.wrap_like(self, output)
def perspective(
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: _FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> BoundingBoxes:
output = self._F.perspective_bounding_boxes(
self.as_subclass(torch.Tensor),
format=self.format,
canvas_size=self.canvas_size,
startpoints=startpoints,
endpoints=endpoints,
coefficients=coefficients,
)
return BoundingBoxes.wrap_like(self, output)
def elastic(
self,
displacement: torch.Tensor,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: _FillTypeJIT = None,
) -> BoundingBoxes:
output = self._F.elastic_bounding_boxes(
self.as_subclass(torch.Tensor), self.format, self.canvas_size, displacement=displacement
)
return BoundingBoxes.wrap_like(self, output)
from __future__ import annotations
from types import ModuleType
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union
import PIL.Image
import torch
from torch._C import DisableTorchFunctionSubclass
from torch.types import _device, _dtype, _size
from torchvision.transforms import InterpolationMode
D = TypeVar("D", bound="Datapoint")
@@ -16,8 +14,6 @@ _FillTypeJIT = Optional[List[float]]
class Datapoint(torch.Tensor):
__F: Optional[ModuleType] = None
@staticmethod
def _to_tensor(
data: Any,
@@ -99,18 +95,6 @@ class Datapoint(torch.Tensor):
extra_repr = ", ".join(f"{key}={value}" for key, value in kwargs.items())
return f"{super().__repr__()[:-1]}, {extra_repr})"
@property
def _F(self) -> ModuleType:
# This implements a lazy import of the functional to get around the cyclic import. This import is deferred
# until the first time we need reference to the functional module and it's shared across all instances of
# the class. This approach avoids the DataLoader issue described at
# https://github.com/pytorch/vision/pull/6476#discussion_r953588621
if Datapoint.__F is None:
from ..transforms.v2 import functional
Datapoint.__F = functional
return Datapoint.__F
# Add properties for common attributes like shape, dtype, device, ndim etc
# this way we return the result without passing into __torch_function__
@property
@@ -142,128 +126,6 @@ class Datapoint(torch.Tensor):
# `BoundingBoxes.clone()`.
return self.detach().clone().requires_grad_(self.requires_grad) # type: ignore[return-value]
def horizontal_flip(self) -> Datapoint:
return self
def vertical_flip(self) -> Datapoint:
return self
# TODO: We have to ignore override mypy error as there is torch.Tensor built-in deprecated op: Tensor.resize
# https://github.com/pytorch/pytorch/blob/e8727994eb7cdb2ab642749d6549bc497563aa06/torch/_tensor.py#L588-L593
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> Datapoint:
return self
def crop(self, top: int, left: int, height: int, width: int) -> Datapoint:
return self
def center_crop(self, output_size: List[int]) -> Datapoint:
return self
def resized_crop(
self,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> Datapoint:
return self
def pad(
self,
padding: List[int],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> Datapoint:
return self
def rotate(
self,
angle: float,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: _FillTypeJIT = None,
) -> Datapoint:
return self
def affine(
self,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: _FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Datapoint:
return self
def perspective(
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: _FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> Datapoint:
return self
def elastic(
self,
displacement: torch.Tensor,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: _FillTypeJIT = None,
) -> Datapoint:
return self
def rgb_to_grayscale(self, num_output_channels: int = 1) -> Datapoint:
return self
def adjust_brightness(self, brightness_factor: float) -> Datapoint:
return self
def adjust_saturation(self, saturation_factor: float) -> Datapoint:
return self
def adjust_contrast(self, contrast_factor: float) -> Datapoint:
return self
def adjust_sharpness(self, sharpness_factor: float) -> Datapoint:
return self
def adjust_hue(self, hue_factor: float) -> Datapoint:
return self
def adjust_gamma(self, gamma: float, gain: float = 1) -> Datapoint:
return self
def posterize(self, bits: int) -> Datapoint:
return self
def solarize(self, threshold: float) -> Datapoint:
return self
def autocontrast(self) -> Datapoint:
return self
def equalize(self) -> Datapoint:
return self
def invert(self) -> Datapoint:
return self
def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Datapoint:
return self
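The methods removed above were no-op fallbacks: unless a subclass overrode them, Datapoint.horizontal_flip(), Datapoint.adjust_brightness(), and the rest simply returned self. In the registry world, the error message of the new test_exhaustive_kernel_registration test refers to the equivalent as @_register_explicit_noop(...). The helper below is only a guess at its shape, continuing the _KERNEL_REGISTRY sketch shown after the commit header; the name comes from this diff, while the signature and behavior are assumptions.

from typing import Callable


def _register_explicit_noop(*datapoint_clss: type) -> Callable:
    # Assumed shape: register an identity kernel for each given datapoint class so
    # the dispatcher passes such inputs through unchanged instead of raising.
    def decorator(dispatcher: Callable) -> Callable:
        for cls in datapoint_clss:
            _KERNEL_REGISTRY.setdefault(dispatcher, {})[cls] = (
                lambda inpt, *args, **kwargs: inpt
            )
        return dispatcher

    return decorator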
_InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint]
_InputTypeJIT = torch.Tensor
from __future__ import annotations
from typing import Any, List, Optional, Union
from typing import Any, Optional, Union
import PIL.Image
import torch
from torchvision.transforms.functional import InterpolationMode
from ._datapoint import _FillTypeJIT, Datapoint
from ._datapoint import Datapoint
class Image(Datapoint):
@@ -56,195 +55,6 @@ class Image(Datapoint):
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
return self._make_repr()
def horizontal_flip(self) -> Image:
output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor))
return Image.wrap_like(self, output)
def vertical_flip(self) -> Image:
output = self._F.vertical_flip_image_tensor(self.as_subclass(torch.Tensor))
return Image.wrap_like(self, output)
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> Image:
output = self._F.resize_image_tensor(
self.as_subclass(torch.Tensor), size, interpolation=interpolation, max_size=max_size, antialias=antialias
)
return Image.wrap_like(self, output)
def crop(self, top: int, left: int, height: int, width: int) -> Image:
output = self._F.crop_image_tensor(self.as_subclass(torch.Tensor), top, left, height, width)
return Image.wrap_like(self, output)
def center_crop(self, output_size: List[int]) -> Image:
output = self._F.center_crop_image_tensor(self.as_subclass(torch.Tensor), output_size=output_size)
return Image.wrap_like(self, output)
def resized_crop(
self,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> Image:
output = self._F.resized_crop_image_tensor(
self.as_subclass(torch.Tensor),
top,
left,
height,
width,
size=list(size),
interpolation=interpolation,
antialias=antialias,
)
return Image.wrap_like(self, output)
def pad(
self,
padding: List[int],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> Image:
output = self._F.pad_image_tensor(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)
return Image.wrap_like(self, output)
def rotate(
self,
angle: float,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: _FillTypeJIT = None,
) -> Image:
output = self._F.rotate_image_tensor(
self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center
)
return Image.wrap_like(self, output)
def affine(
self,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: _FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Image:
output = self._F.affine_image_tensor(
self.as_subclass(torch.Tensor),
angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
return Image.wrap_like(self, output)
def perspective(
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: _FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> Image:
output = self._F.perspective_image_tensor(
self.as_subclass(torch.Tensor),
startpoints,
endpoints,
interpolation=interpolation,
fill=fill,
coefficients=coefficients,
)
return Image.wrap_like(self, output)
def elastic(
self,
displacement: torch.Tensor,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: _FillTypeJIT = None,
) -> Image:
output = self._F.elastic_image_tensor(
self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill
)
return Image.wrap_like(self, output)
def rgb_to_grayscale(self, num_output_channels: int = 1) -> Image:
output = self._F.rgb_to_grayscale_image_tensor(
self.as_subclass(torch.Tensor), num_output_channels=num_output_channels
)
return Image.wrap_like(self, output)
def adjust_brightness(self, brightness_factor: float) -> Image:
output = self._F.adjust_brightness_image_tensor(
self.as_subclass(torch.Tensor), brightness_factor=brightness_factor
)
return Image.wrap_like(self, output)
def adjust_saturation(self, saturation_factor: float) -> Image:
output = self._F.adjust_saturation_image_tensor(
self.as_subclass(torch.Tensor), saturation_factor=saturation_factor
)
return Image.wrap_like(self, output)
def adjust_contrast(self, contrast_factor: float) -> Image:
output = self._F.adjust_contrast_image_tensor(self.as_subclass(torch.Tensor), contrast_factor=contrast_factor)
return Image.wrap_like(self, output)
def adjust_sharpness(self, sharpness_factor: float) -> Image:
output = self._F.adjust_sharpness_image_tensor(
self.as_subclass(torch.Tensor), sharpness_factor=sharpness_factor
)
return Image.wrap_like(self, output)
def adjust_hue(self, hue_factor: float) -> Image:
output = self._F.adjust_hue_image_tensor(self.as_subclass(torch.Tensor), hue_factor=hue_factor)
return Image.wrap_like(self, output)
def adjust_gamma(self, gamma: float, gain: float = 1) -> Image:
output = self._F.adjust_gamma_image_tensor(self.as_subclass(torch.Tensor), gamma=gamma, gain=gain)
return Image.wrap_like(self, output)
def posterize(self, bits: int) -> Image:
output = self._F.posterize_image_tensor(self.as_subclass(torch.Tensor), bits=bits)
return Image.wrap_like(self, output)
def solarize(self, threshold: float) -> Image:
output = self._F.solarize_image_tensor(self.as_subclass(torch.Tensor), threshold=threshold)
return Image.wrap_like(self, output)
def autocontrast(self) -> Image:
output = self._F.autocontrast_image_tensor(self.as_subclass(torch.Tensor))
return Image.wrap_like(self, output)
def equalize(self) -> Image:
output = self._F.equalize_image_tensor(self.as_subclass(torch.Tensor))
return Image.wrap_like(self, output)
def invert(self) -> Image:
output = self._F.invert_image_tensor(self.as_subclass(torch.Tensor))
return Image.wrap_like(self, output)
def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image:
output = self._F.gaussian_blur_image_tensor(
self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma
)
return Image.wrap_like(self, output)
def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Image:
output = self._F.normalize_image_tensor(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace)
return Image.wrap_like(self, output)
_ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
_ImageTypeJIT = torch.Tensor
...
from __future__ import annotations
from typing import Any, List, Optional, Union
from typing import Any, Optional, Union
import PIL.Image
import torch
from torchvision.transforms import InterpolationMode
from ._datapoint import _FillTypeJIT, Datapoint
from ._datapoint import Datapoint
class Mask(Datapoint):
@@ -50,105 +49,3 @@ class Mask(Datapoint):
tensor: torch.Tensor,
) -> Mask:
return cls._wrap(tensor)
def horizontal_flip(self) -> Mask:
output = self._F.horizontal_flip_mask(self.as_subclass(torch.Tensor))
return Mask.wrap_like(self, output)
def vertical_flip(self) -> Mask:
output = self._F.vertical_flip_mask(self.as_subclass(torch.Tensor))
return Mask.wrap_like(self, output)
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> Mask:
output = self._F.resize_mask(self.as_subclass(torch.Tensor), size, max_size=max_size)
return Mask.wrap_like(self, output)
def crop(self, top: int, left: int, height: int, width: int) -> Mask:
output = self._F.crop_mask(self.as_subclass(torch.Tensor), top, left, height, width)
return Mask.wrap_like(self, output)
def center_crop(self, output_size: List[int]) -> Mask:
output = self._F.center_crop_mask(self.as_subclass(torch.Tensor), output_size=output_size)
return Mask.wrap_like(self, output)
def resized_crop(
self,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
antialias: Optional[Union[str, bool]] = "warn",
) -> Mask:
output = self._F.resized_crop_mask(self.as_subclass(torch.Tensor), top, left, height, width, size=size)
return Mask.wrap_like(self, output)
def pad(
self,
padding: List[int],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> Mask:
output = self._F.pad_mask(self.as_subclass(torch.Tensor), padding, padding_mode=padding_mode, fill=fill)
return Mask.wrap_like(self, output)
def rotate(
self,
angle: float,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: _FillTypeJIT = None,
) -> Mask:
output = self._F.rotate_mask(self.as_subclass(torch.Tensor), angle, expand=expand, center=center, fill=fill)
return Mask.wrap_like(self, output)
def affine(
self,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: _FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Mask:
output = self._F.affine_mask(
self.as_subclass(torch.Tensor),
angle,
translate=translate,
scale=scale,
shear=shear,
fill=fill,
center=center,
)
return Mask.wrap_like(self, output)
def perspective(
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: _FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> Mask:
output = self._F.perspective_mask(
self.as_subclass(torch.Tensor), startpoints, endpoints, fill=fill, coefficients=coefficients
)
return Mask.wrap_like(self, output)
def elastic(
self,
displacement: torch.Tensor,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: _FillTypeJIT = None,
) -> Mask:
output = self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill)
return Mask.wrap_like(self, output)
from __future__ import annotations
from typing import Any, List, Optional, Union
from typing import Any, Optional, Union
import torch
from torchvision.transforms.functional import InterpolationMode
from ._datapoint import _FillTypeJIT, Datapoint
from ._datapoint import Datapoint
class Video(Datapoint):
@@ -46,191 +45,6 @@ class Video(Datapoint):
def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
return self._make_repr()
def horizontal_flip(self) -> Video:
output = self._F.horizontal_flip_video(self.as_subclass(torch.Tensor))
return Video.wrap_like(self, output)
def vertical_flip(self) -> Video:
output = self._F.vertical_flip_video(self.as_subclass(torch.Tensor))
return Video.wrap_like(self, output)
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> Video:
output = self._F.resize_video(
self.as_subclass(torch.Tensor),
size,
interpolation=interpolation,
max_size=max_size,
antialias=antialias,
)
return Video.wrap_like(self, output)
def crop(self, top: int, left: int, height: int, width: int) -> Video:
output = self._F.crop_video(self.as_subclass(torch.Tensor), top, left, height, width)
return Video.wrap_like(self, output)
def center_crop(self, output_size: List[int]) -> Video:
output = self._F.center_crop_video(self.as_subclass(torch.Tensor), output_size=output_size)
return Video.wrap_like(self, output)
def resized_crop(
self,
top: int,
left: int,
height: int,
width: int,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> Video:
output = self._F.resized_crop_video(
self.as_subclass(torch.Tensor),
top,
left,
height,
width,
size=list(size),
interpolation=interpolation,
antialias=antialias,
)
return Video.wrap_like(self, output)
def pad(
self,
padding: List[int],
fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> Video:
output = self._F.pad_video(self.as_subclass(torch.Tensor), padding, fill=fill, padding_mode=padding_mode)
return Video.wrap_like(self, output)
def rotate(
self,
angle: float,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: _FillTypeJIT = None,
) -> Video:
output = self._F.rotate_video(
self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center
)
return Video.wrap_like(self, output)
def affine(
self,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: _FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Video:
output = self._F.affine_video(
self.as_subclass(torch.Tensor),
angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
return Video.wrap_like(self, output)
def perspective(
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: _FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> Video:
output = self._F.perspective_video(
self.as_subclass(torch.Tensor),
startpoints,
endpoints,
interpolation=interpolation,
fill=fill,
coefficients=coefficients,
)
return Video.wrap_like(self, output)
def elastic(
self,
displacement: torch.Tensor,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: _FillTypeJIT = None,
) -> Video:
output = self._F.elastic_video(
self.as_subclass(torch.Tensor), displacement, interpolation=interpolation, fill=fill
)
return Video.wrap_like(self, output)
def rgb_to_grayscale(self, num_output_channels: int = 1) -> Video:
output = self._F.rgb_to_grayscale_image_tensor(
self.as_subclass(torch.Tensor), num_output_channels=num_output_channels
)
return Video.wrap_like(self, output)
def adjust_brightness(self, brightness_factor: float) -> Video:
output = self._F.adjust_brightness_video(self.as_subclass(torch.Tensor), brightness_factor=brightness_factor)
return Video.wrap_like(self, output)
def adjust_saturation(self, saturation_factor: float) -> Video:
output = self._F.adjust_saturation_video(self.as_subclass(torch.Tensor), saturation_factor=saturation_factor)
return Video.wrap_like(self, output)
def adjust_contrast(self, contrast_factor: float) -> Video:
output = self._F.adjust_contrast_video(self.as_subclass(torch.Tensor), contrast_factor=contrast_factor)
return Video.wrap_like(self, output)
def adjust_sharpness(self, sharpness_factor: float) -> Video:
output = self._F.adjust_sharpness_video(self.as_subclass(torch.Tensor), sharpness_factor=sharpness_factor)
return Video.wrap_like(self, output)
def adjust_hue(self, hue_factor: float) -> Video:
output = self._F.adjust_hue_video(self.as_subclass(torch.Tensor), hue_factor=hue_factor)
return Video.wrap_like(self, output)
def adjust_gamma(self, gamma: float, gain: float = 1) -> Video:
output = self._F.adjust_gamma_video(self.as_subclass(torch.Tensor), gamma=gamma, gain=gain)
return Video.wrap_like(self, output)
def posterize(self, bits: int) -> Video:
output = self._F.posterize_video(self.as_subclass(torch.Tensor), bits=bits)
return Video.wrap_like(self, output)
def solarize(self, threshold: float) -> Video:
output = self._F.solarize_video(self.as_subclass(torch.Tensor), threshold=threshold)
return Video.wrap_like(self, output)
def autocontrast(self) -> Video:
output = self._F.autocontrast_video(self.as_subclass(torch.Tensor))
return Video.wrap_like(self, output)
def equalize(self) -> Video:
output = self._F.equalize_video(self.as_subclass(torch.Tensor))
return Video.wrap_like(self, output)
def invert(self) -> Video:
output = self._F.invert_video(self.as_subclass(torch.Tensor))
return Video.wrap_like(self, output)
def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Video:
output = self._F.gaussian_blur_video(self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma)
return Video.wrap_like(self, output)
def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Video:
output = self._F.normalize_video(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace)
return Video.wrap_like(self, output)
_VideoType = Union[torch.Tensor, Video] _VideoType = Union[torch.Tensor, Video]
_VideoTypeJIT = torch.Tensor _VideoTypeJIT = torch.Tensor
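Editor's note on the removal above: with the per-datapoint methods gone, resizing, flipping, and the other ops on a `Video` go through the dispatchers in `torchvision.transforms.v2.functional`, which look up the registered `*_video` kernels and re-wrap the output. A minimal sketch of the call-site change (not part of the commit; it assumes the functional namespace is imported as `F`):

# --- editor's sketch (not part of the diff) ---
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

video = datapoints.Video(torch.rand(4, 3, 32, 32))  # (T, C, H, W)

# previously: video.resize([16, 16], antialias=True) / video.horizontal_flip()
resized = F.resize(video, [16, 16], antialias=True)
flipped = F.horizontal_flip(video)
assert isinstance(resized, datapoints.Video) and isinstance(flipped, datapoints.Video)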
......
import math import math
import numbers import numbers
import warnings import warnings
from typing import Any, Dict, List, Tuple, Union from typing import Any, Dict, List, Tuple
import PIL.Image import PIL.Image
import torch import torch
...@@ -56,8 +56,6 @@ class RandomErasing(_RandomApplyTransform): ...@@ -56,8 +56,6 @@ class RandomErasing(_RandomApplyTransform):
value="random" if self.value is None else self.value, value="random" if self.value is None else self.value,
) )
_transformed_types = (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video)
def __init__( def __init__(
self, self,
p: float = 0.5, p: float = 0.5,
...@@ -131,9 +129,7 @@ class RandomErasing(_RandomApplyTransform): ...@@ -131,9 +129,7 @@ class RandomErasing(_RandomApplyTransform):
return dict(i=i, j=j, h=h, w=w, v=v) return dict(i=i, j=j, h=h, w=w, v=v)
def _transform( def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
) -> Union[datapoints._ImageType, datapoints._VideoType]:
if params["v"] is not None: if params["v"] is not None:
inpt = F.erase(inpt, **params, inplace=self.inplace) inpt = F.erase(inpt, **params, inplace=self.inplace)
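With the class-level `_transformed_types` removed, `RandomErasing._transform` accepts `Any` and simply forwards to `F.erase`; which input types actually get erased is decided by the kernel table the functional layer builds later in this diff, where BoundingBoxes and Mask are registered as explicit pass-throughs. A hedged sketch of what that means for a nested sample — the `BoundingBoxes(..., format=..., canvas_size=...)` construction and the pass-through (possibly with a warning, per `warn_passthrough=True` below) are assumptions based on those registrations:

# --- editor's sketch (not part of the diff) ---
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import RandomErasing

sample = {
    "image": datapoints.Image(torch.rand(3, 32, 32)),
    "boxes": datapoints.BoundingBoxes(
        torch.tensor([[0.0, 0.0, 8.0, 8.0]]), format="XYXY", canvas_size=(32, 32)
    ),
}
out = RandomErasing(p=1.0)(sample)  # the image is erased; the boxes reach F.erase and pass through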
......
...@@ -355,20 +355,11 @@ class FiveCrop(Transform): ...@@ -355,20 +355,11 @@ class FiveCrop(Transform):
_v1_transform_cls = _transforms.FiveCrop _v1_transform_cls = _transforms.FiveCrop
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(self, size: Union[int, Sequence[int]]) -> None: def __init__(self, size: Union[int, Sequence[int]]) -> None:
super().__init__() super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
def _transform( def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
self, inpt: ImageOrVideoTypeJIT, params: Dict[str, Any]
) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
return F.five_crop(inpt, self.size) return F.five_crop(inpt, self.size)
def _check_inputs(self, flat_inputs: List[Any]) -> None: def _check_inputs(self, flat_inputs: List[Any]) -> None:
...@@ -402,13 +393,6 @@ class TenCrop(Transform): ...@@ -402,13 +393,6 @@ class TenCrop(Transform):
_v1_transform_cls = _transforms.TenCrop _v1_transform_cls = _transforms.TenCrop
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)
def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None: def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
super().__init__() super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
...@@ -418,20 +402,7 @@ class TenCrop(Transform): ...@@ -418,20 +402,7 @@ class TenCrop(Transform):
if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask): if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):
raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()") raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")
def _transform( def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
) -> Tuple[
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
ImageOrVideoTypeJIT,
]:
return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip) return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
......
from typing import Any, Dict from typing import Any, Dict
import torch
from torchvision import datapoints from torchvision import datapoints
from torchvision.transforms.v2 import functional as F, Transform from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2.utils import is_simple_tensor
class UniformTemporalSubsample(Transform): class UniformTemporalSubsample(Transform):
"""[BETA] Uniformly subsample ``num_samples`` indices from the temporal dimension of the video. """[BETA] Uniformly subsample ``num_samples`` indices from the temporal dimension of the video.
...@@ -20,7 +19,7 @@ class UniformTemporalSubsample(Transform): ...@@ -20,7 +19,7 @@ class UniformTemporalSubsample(Transform):
num_samples (int): The number of equispaced samples to be selected num_samples (int): The number of equispaced samples to be selected
""" """
_transformed_types = (is_simple_tensor, datapoints.Video) _transformed_types = (torch.Tensor,)
def __init__(self, num_samples: int): def __init__(self, num_samples: int):
super().__init__() super().__init__()
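`UniformTemporalSubsample` now declares `_transformed_types = (torch.Tensor,)`, so plain video tensors and `datapoints.Video` inputs both reach the functional op, which picks the kernel. A small usage sketch, assuming the kernel subsamples along the leading temporal dimension as in v1:

# --- editor's sketch (not part of the diff) ---
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import UniformTemporalSubsample

subsample = UniformTemporalSubsample(num_samples=4)
clip = torch.rand(8, 3, 16, 16)           # plain (T, C, H, W) tensor
video = datapoints.Video(clip)

assert subsample(clip).shape[0] == 4      # simple tensors take the kernel fast path
assert isinstance(subsample(video), datapoints.Video)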
......
from torchvision.transforms import InterpolationMode # usort: skip from torchvision.transforms import InterpolationMode # usort: skip
from ._utils import is_simple_tensor # usort: skip from ._utils import is_simple_tensor, register_kernel # usort: skip
from ._meta import ( from ._meta import (
clamp_bounding_boxes, clamp_bounding_boxes,
convert_format_bounding_boxes, convert_format_bounding_boxes,
get_dimensions_image_tensor, get_dimensions_image_tensor,
get_dimensions_image_pil, get_dimensions_image_pil,
get_dimensions_video,
get_dimensions, get_dimensions,
get_num_frames_video, get_num_frames_video,
get_num_frames, get_num_frames,
......
...@@ -7,9 +7,37 @@ from torchvision import datapoints ...@@ -7,9 +7,37 @@ from torchvision import datapoints
from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import _log_api_usage_once from torchvision.utils import _log_api_usage_once
from ._utils import is_simple_tensor from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor
@_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True)
def erase(
inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT],
i: int,
j: int,
h: int,
w: int,
v: torch.Tensor,
inplace: bool = False,
) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
if not torch.jit.is_scripting():
_log_api_usage_once(erase)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(erase, type(inpt))
return kernel(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
elif isinstance(inpt, PIL.Image.Image):
return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
@_register_kernel_internal(erase, datapoints.Image)
def erase_image_tensor( def erase_image_tensor(
image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -29,36 +57,8 @@ def erase_image_pil( ...@@ -29,36 +57,8 @@ def erase_image_pil(
return to_pil_image(output, mode=image.mode) return to_pil_image(output, mode=image.mode)
@_register_kernel_internal(erase, datapoints.Video)
def erase_video( def erase_video(
video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False video: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
return erase_image_tensor(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace) return erase_image_tensor(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
def erase(
inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT],
i: int,
j: int,
h: int,
w: int,
v: torch.Tensor,
inplace: bool = False,
) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
if not torch.jit.is_scripting():
_log_api_usage_once(erase)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
elif isinstance(inpt, datapoints.Image):
output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
return datapoints.Image.wrap_like(inpt, output)
elif isinstance(inpt, datapoints.Video):
output = erase_video(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
return datapoints.Video.wrap_like(inpt, output)
elif isinstance(inpt, PIL.Image.Image):
return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
else:
raise TypeError(
f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
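The hunk above is the template the rest of this commit follows: the dispatcher is defined first and, for datapoint inputs, looks its kernel up via `_get_kernel`; kernels are attached with `@_register_kernel_internal`, and types the op does not apply to are declared with `_register_explicit_noop`. Because the lookup is table-driven, the `register_kernel` helper exported earlier can, in principle, attach kernels for user-defined datapoint classes. A minimal sketch under that assumption — `ThermalImage`, `erase_thermal_image`, and the decorator-factory usage of `register_kernel(dispatcher, datapoint_cls)` are illustrative, not part of the commit:

# --- editor's sketch (not part of the diff); names below are hypothetical ---
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F


class ThermalImage(datapoints.Datapoint):  # hypothetical user-defined datapoint
    pass


@F.register_kernel(F.erase, ThermalImage)  # assumes register_kernel(dispatcher, datapoint_cls) works as a decorator factory
def erase_thermal_image(inpt, i, j, h, w, v, inplace=False):
    # reuse the plain-tensor path of the dispatcher, then restore the subclass
    out = F.erase(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
    return out.as_subclass(ThermalImage)


thermal = torch.rand(1, 8, 8).as_subclass(ThermalImage)  # built via as_subclass to avoid assuming a constructor
out = F.erase(thermal, i=0, j=0, h=2, w=2, v=torch.zeros(1, 2, 2))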
...@@ -10,7 +10,34 @@ from torchvision.transforms._functional_tensor import _max_value ...@@ -10,7 +10,34 @@ from torchvision.transforms._functional_tensor import _max_value
from torchvision.utils import _log_api_usage_once from torchvision.utils import _log_api_usage_once
from ._misc import _num_value_bits, to_dtype_image_tensor from ._misc import _num_value_bits, to_dtype_image_tensor
from ._utils import is_simple_tensor from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, is_simple_tensor
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, datapoints.Video)
def rgb_to_grayscale(
inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], num_output_channels: int = 1
) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]:
if not torch.jit.is_scripting():
_log_api_usage_once(rgb_to_grayscale)
if num_output_channels not in (1, 3):
raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(rgb_to_grayscale, type(inpt))
return kernel(inpt, num_output_channels=num_output_channels)
elif isinstance(inpt, PIL.Image.Image):
return rgb_to_grayscale_image_pil(inpt, num_output_channels=num_output_channels)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
# `to_grayscale` actually predates `rgb_to_grayscale` in v1, but only handles PIL images. Since `rgb_to_grayscale` is a
# superset in terms of functionality and has the same signature, we alias here to avoid disruption.
to_grayscale = rgb_to_grayscale
def _rgb_to_grayscale_image_tensor( def _rgb_to_grayscale_image_tensor(
...@@ -29,6 +56,7 @@ def _rgb_to_grayscale_image_tensor( ...@@ -29,6 +56,7 @@ def _rgb_to_grayscale_image_tensor(
return l_img return l_img
@_register_kernel_internal(rgb_to_grayscale, datapoints.Image)
def rgb_to_grayscale_image_tensor(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: def rgb_to_grayscale_image_tensor(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
return _rgb_to_grayscale_image_tensor(image, num_output_channels=num_output_channels, preserve_dtype=True) return _rgb_to_grayscale_image_tensor(image, num_output_channels=num_output_channels, preserve_dtype=True)
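`_register_explicit_noop` pre-registers pass-through kernels for the listed datapoint types (here BoundingBoxes, Mask, and Video), so the dispatcher above can hand any datapoint to `_get_kernel` without special-casing unsupported types; `to_grayscale` stays an alias for v1 compatibility. A sketch of the observable behaviour, assuming the no-op simply returns its input:

# --- editor's sketch (not part of the diff) ---
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

img = datapoints.Image(torch.rand(3, 8, 8))
mask = datapoints.Mask(torch.zeros(8, 8, dtype=torch.int64))

gray = F.rgb_to_grayscale(img)          # dispatched to the Image kernel registered above
passthrough = F.rgb_to_grayscale(mask)  # explicit no-op: masks come back unchanged
assert gray.shape[-3] == 1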
...@@ -36,19 +64,26 @@ def rgb_to_grayscale_image_tensor(image: torch.Tensor, num_output_channels: int ...@@ -36,19 +64,26 @@ def rgb_to_grayscale_image_tensor(image: torch.Tensor, num_output_channels: int
rgb_to_grayscale_image_pil = _FP.to_grayscale rgb_to_grayscale_image_pil = _FP.to_grayscale
def rgb_to_grayscale( def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], num_output_channels: int = 1 ratio = float(ratio)
) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]: fp = image1.is_floating_point()
bound = _max_value(image1.dtype)
output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound)
return output if fp else output.to(image1.dtype)
@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(rgb_to_grayscale) _log_api_usage_once(adjust_brightness)
if num_output_channels not in (1, 3):
raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
if torch.jit.is_scripting() or is_simple_tensor(inpt): if torch.jit.is_scripting() or is_simple_tensor(inpt):
return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels) return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints.Datapoint):
return inpt.rgb_to_grayscale(num_output_channels=num_output_channels) kernel = _get_kernel(adjust_brightness, type(inpt))
return kernel(inpt, brightness_factor=brightness_factor)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return rgb_to_grayscale_image_pil(inpt, num_output_channels=num_output_channels) return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...@@ -56,19 +91,7 @@ def rgb_to_grayscale( ...@@ -56,19 +91,7 @@ def rgb_to_grayscale(
) )
# `to_grayscale` actually predates `rgb_to_grayscale` in v1, but only handles PIL images. Since `rgb_to_grayscale` is a @_register_kernel_internal(adjust_brightness, datapoints.Image)
# superset in terms of functionality and has the same signature, we alias here to avoid disruption.
to_grayscale = rgb_to_grayscale
def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
ratio = float(ratio)
fp = image1.is_floating_point()
bound = _max_value(image1.dtype)
output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound)
return output if fp else output.to(image1.dtype)
def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor: def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor:
if brightness_factor < 0: if brightness_factor < 0:
raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.") raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
...@@ -83,23 +106,27 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float ...@@ -83,23 +106,27 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float
return output if fp else output.to(image.dtype) return output if fp else output.to(image.dtype)
adjust_brightness_image_pil = _FP.adjust_brightness def adjust_brightness_image_pil(image: PIL.Image.Image, brightness_factor: float) -> PIL.Image.Image:
return _FP.adjust_brightness(image, brightness_factor=brightness_factor)
@_register_kernel_internal(adjust_brightness, datapoints.Video)
def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor: def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor:
return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor) return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor)
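Each colour op now follows the same layout: the dispatcher comes first, then the Image kernel, the PIL helper, and the Video kernel, each attached with `@_register_kernel_internal`, plus a `_register_explicit_noop` for BoundingBoxes and Mask. A short usage sketch of the dispatcher across input types (not part of the commit; kernel names as registered above):

# --- editor's sketch (not part of the diff) ---
import torch
import PIL.Image
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

plain = torch.rand(3, 8, 8)                        # simple tensor -> fast path, no wrapping
image = datapoints.Image(plain)                    # -> Image kernel via _get_kernel
video = datapoints.Video(torch.rand(2, 3, 8, 8))   # -> adjust_brightness_video
pil = PIL.Image.new("RGB", (8, 8))                 # -> adjust_brightness_image_pil

outputs = [F.adjust_brightness(x, brightness_factor=1.5) for x in (plain, image, video, pil)]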
def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) -> datapoints._InputTypeJIT: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def adjust_saturation(inpt: datapoints._InputTypeJIT, saturation_factor: float) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(adjust_brightness) _log_api_usage_once(adjust_saturation)
if torch.jit.is_scripting() or is_simple_tensor(inpt): if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints.Datapoint):
return inpt.adjust_brightness(brightness_factor=brightness_factor) kernel = _get_kernel(adjust_saturation, type(inpt))
return kernel(inpt, saturation_factor=saturation_factor)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor) return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...@@ -107,6 +134,7 @@ def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) ...@@ -107,6 +134,7 @@ def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float)
) )
@_register_kernel_internal(adjust_saturation, datapoints.Image)
def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor: def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
if saturation_factor < 0: if saturation_factor < 0:
raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.") raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
...@@ -128,22 +156,23 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float ...@@ -128,22 +156,23 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float
adjust_saturation_image_pil = _FP.adjust_saturation adjust_saturation_image_pil = _FP.adjust_saturation
@_register_kernel_internal(adjust_saturation, datapoints.Video)
def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor: def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor:
return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor) return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor)
def adjust_saturation(inpt: datapoints._InputTypeJIT, saturation_factor: float) -> datapoints._InputTypeJIT: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def adjust_contrast(inpt: datapoints._InputTypeJIT, contrast_factor: float) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(adjust_saturation) _log_api_usage_once(adjust_contrast)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor)
): elif isinstance(inpt, datapoints.Datapoint):
return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor) kernel = _get_kernel(adjust_contrast, type(inpt))
elif isinstance(inpt, datapoints._datapoint.Datapoint): return kernel(inpt, contrast_factor=contrast_factor)
return inpt.adjust_saturation(saturation_factor=saturation_factor)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor) return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...@@ -151,6 +180,7 @@ def adjust_saturation(inpt: datapoints._InputTypeJIT, saturation_factor: float) ...@@ -151,6 +180,7 @@ def adjust_saturation(inpt: datapoints._InputTypeJIT, saturation_factor: float)
) )
@_register_kernel_internal(adjust_contrast, datapoints.Image)
def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor: def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
if contrast_factor < 0: if contrast_factor < 0:
raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.") raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
...@@ -172,20 +202,23 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> ...@@ -172,20 +202,23 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
adjust_contrast_image_pil = _FP.adjust_contrast adjust_contrast_image_pil = _FP.adjust_contrast
@_register_kernel_internal(adjust_contrast, datapoints.Video)
def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor: def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor:
return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor) return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor)
def adjust_contrast(inpt: datapoints._InputTypeJIT, contrast_factor: float) -> datapoints._InputTypeJIT: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def adjust_sharpness(inpt: datapoints._InputTypeJIT, sharpness_factor: float) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(adjust_contrast) _log_api_usage_once(adjust_sharpness)
if torch.jit.is_scripting() or is_simple_tensor(inpt): if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, datapoints.Datapoint)):
return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints.Datapoint):
return inpt.adjust_contrast(contrast_factor=contrast_factor) kernel = _get_kernel(adjust_sharpness, type(inpt))
return kernel(inpt, sharpness_factor=sharpness_factor)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor) return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...@@ -193,6 +226,7 @@ def adjust_contrast(inpt: datapoints._InputTypeJIT, contrast_factor: float) -> d ...@@ -193,6 +226,7 @@ def adjust_contrast(inpt: datapoints._InputTypeJIT, contrast_factor: float) -> d
) )
@_register_kernel_internal(adjust_sharpness, datapoints.Image)
def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor: def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
num_channels, height, width = image.shape[-3:] num_channels, height, width = image.shape[-3:]
if num_channels not in (1, 3): if num_channels not in (1, 3):
...@@ -248,22 +282,23 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) ...@@ -248,22 +282,23 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
adjust_sharpness_image_pil = _FP.adjust_sharpness adjust_sharpness_image_pil = _FP.adjust_sharpness
@_register_kernel_internal(adjust_sharpness, datapoints.Video)
def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor: def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor) return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor)
def adjust_sharpness(inpt: datapoints._InputTypeJIT, sharpness_factor: float) -> datapoints._InputTypeJIT: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def adjust_hue(inpt: datapoints._InputTypeJIT, hue_factor: float) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(adjust_sharpness) _log_api_usage_once(adjust_hue)
if isinstance(inpt, torch.Tensor) and ( if torch.jit.is_scripting() or is_simple_tensor(inpt):
torch.jit.is_scripting() or not isinstance(inpt, datapoints._datapoint.Datapoint) return adjust_hue_image_tensor(inpt, hue_factor=hue_factor)
): elif isinstance(inpt, datapoints.Datapoint):
return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor) kernel = _get_kernel(adjust_hue, type(inpt))
elif isinstance(inpt, datapoints._datapoint.Datapoint): return kernel(inpt, hue_factor=hue_factor)
return inpt.adjust_sharpness(sharpness_factor=sharpness_factor)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor) return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...@@ -335,6 +370,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor: ...@@ -335,6 +370,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
return (a4.mul_(mask.unsqueeze(dim=-4))).sum(dim=-3) return (a4.mul_(mask.unsqueeze(dim=-4))).sum(dim=-3)
@_register_kernel_internal(adjust_hue, datapoints.Image)
def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor: def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
if not (-0.5 <= hue_factor <= 0.5): if not (-0.5 <= hue_factor <= 0.5):
raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].") raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
...@@ -365,20 +401,23 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten ...@@ -365,20 +401,23 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
adjust_hue_image_pil = _FP.adjust_hue adjust_hue_image_pil = _FP.adjust_hue
@_register_kernel_internal(adjust_hue, datapoints.Video)
def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor: def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor:
return adjust_hue_image_tensor(video, hue_factor=hue_factor) return adjust_hue_image_tensor(video, hue_factor=hue_factor)
def adjust_hue(inpt: datapoints._InputTypeJIT, hue_factor: float) -> datapoints._InputTypeJIT: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def adjust_gamma(inpt: datapoints._InputTypeJIT, gamma: float, gain: float = 1) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(adjust_hue) _log_api_usage_once(adjust_gamma)
if torch.jit.is_scripting() or is_simple_tensor(inpt): if torch.jit.is_scripting() or is_simple_tensor(inpt):
return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints.Datapoint):
return inpt.adjust_hue(hue_factor=hue_factor) kernel = _get_kernel(adjust_gamma, type(inpt))
return kernel(inpt, gamma=gamma, gain=gain)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return adjust_hue_image_pil(inpt, hue_factor=hue_factor) return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...@@ -386,6 +425,7 @@ def adjust_hue(inpt: datapoints._InputTypeJIT, hue_factor: float) -> datapoints. ...@@ -386,6 +425,7 @@ def adjust_hue(inpt: datapoints._InputTypeJIT, hue_factor: float) -> datapoints.
) )
@_register_kernel_internal(adjust_gamma, datapoints.Image)
def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor: def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
if gamma < 0: if gamma < 0:
raise ValueError("Gamma should be a non-negative real number") raise ValueError("Gamma should be a non-negative real number")
...@@ -408,20 +448,23 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1 ...@@ -408,20 +448,23 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1
adjust_gamma_image_pil = _FP.adjust_gamma adjust_gamma_image_pil = _FP.adjust_gamma
@_register_kernel_internal(adjust_gamma, datapoints.Video)
def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor: def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor:
return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain) return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain)
def adjust_gamma(inpt: datapoints._InputTypeJIT, gamma: float, gain: float = 1) -> datapoints._InputTypeJIT: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def posterize(inpt: datapoints._InputTypeJIT, bits: int) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(adjust_gamma) _log_api_usage_once(posterize)
if torch.jit.is_scripting() or is_simple_tensor(inpt): if torch.jit.is_scripting() or is_simple_tensor(inpt):
return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) return posterize_image_tensor(inpt, bits=bits)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints.Datapoint):
return inpt.adjust_gamma(gamma=gamma, gain=gain) kernel = _get_kernel(posterize, type(inpt))
return kernel(inpt, bits=bits)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain) return posterize_image_pil(inpt, bits=bits)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...@@ -429,6 +472,7 @@ def adjust_gamma(inpt: datapoints._InputTypeJIT, gamma: float, gain: float = 1) ...@@ -429,6 +472,7 @@ def adjust_gamma(inpt: datapoints._InputTypeJIT, gamma: float, gain: float = 1)
) )
@_register_kernel_internal(posterize, datapoints.Image)
def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor: def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
if image.is_floating_point(): if image.is_floating_point():
levels = 1 << bits levels = 1 << bits
...@@ -445,20 +489,23 @@ def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor: ...@@ -445,20 +489,23 @@ def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
posterize_image_pil = _FP.posterize posterize_image_pil = _FP.posterize
@_register_kernel_internal(posterize, datapoints.Video)
def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor: def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor:
return posterize_image_tensor(video, bits=bits) return posterize_image_tensor(video, bits=bits)
def posterize(inpt: datapoints._InputTypeJIT, bits: int) -> datapoints._InputTypeJIT: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(posterize) _log_api_usage_once(solarize)
if torch.jit.is_scripting() or is_simple_tensor(inpt): if torch.jit.is_scripting() or is_simple_tensor(inpt):
return posterize_image_tensor(inpt, bits=bits) return solarize_image_tensor(inpt, threshold=threshold)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints.Datapoint):
return inpt.posterize(bits=bits) kernel = _get_kernel(solarize, type(inpt))
return kernel(inpt, threshold=threshold)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return posterize_image_pil(inpt, bits=bits) return solarize_image_pil(inpt, threshold=threshold)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...@@ -466,6 +513,7 @@ def posterize(inpt: datapoints._InputTypeJIT, bits: int) -> datapoints._InputTyp ...@@ -466,6 +513,7 @@ def posterize(inpt: datapoints._InputTypeJIT, bits: int) -> datapoints._InputTyp
) )
@_register_kernel_internal(solarize, datapoints.Image)
def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor: def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
if threshold > _max_value(image.dtype): if threshold > _max_value(image.dtype):
raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}") raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")
...@@ -476,20 +524,25 @@ def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor ...@@ -476,20 +524,25 @@ def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor
solarize_image_pil = _FP.solarize solarize_image_pil = _FP.solarize
@_register_kernel_internal(solarize, datapoints.Video)
def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor: def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor:
return solarize_image_tensor(video, threshold=threshold) return solarize_image_tensor(video, threshold=threshold)
def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._InputTypeJIT: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(solarize) _log_api_usage_once(autocontrast)
if torch.jit.is_scripting() or is_simple_tensor(inpt): if torch.jit.is_scripting() or is_simple_tensor(inpt):
return solarize_image_tensor(inpt, threshold=threshold) return autocontrast_image_tensor(inpt)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints.Datapoint):
return inpt.solarize(threshold=threshold) kernel = _get_kernel(autocontrast, type(inpt))
return kernel(
inpt,
)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return solarize_image_pil(inpt, threshold=threshold) return autocontrast_image_pil(inpt)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...@@ -497,6 +550,7 @@ def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._In ...@@ -497,6 +550,7 @@ def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._In
) )
@_register_kernel_internal(autocontrast, datapoints.Image)
def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor: def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
c = image.shape[-3] c = image.shape[-3]
if c not in [1, 3]: if c not in [1, 3]:
...@@ -529,20 +583,25 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor: ...@@ -529,20 +583,25 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
autocontrast_image_pil = _FP.autocontrast autocontrast_image_pil = _FP.autocontrast
@_register_kernel_internal(autocontrast, datapoints.Video)
def autocontrast_video(video: torch.Tensor) -> torch.Tensor: def autocontrast_video(video: torch.Tensor) -> torch.Tensor:
return autocontrast_image_tensor(video) return autocontrast_image_tensor(video)
def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(autocontrast) _log_api_usage_once(equalize)
if torch.jit.is_scripting() or is_simple_tensor(inpt): if torch.jit.is_scripting() or is_simple_tensor(inpt):
return autocontrast_image_tensor(inpt) return equalize_image_tensor(inpt)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints.Datapoint):
return inpt.autocontrast() kernel = _get_kernel(equalize, type(inpt))
return kernel(
inpt,
)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return autocontrast_image_pil(inpt) return equalize_image_pil(inpt)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...@@ -550,6 +609,7 @@ def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: ...@@ -550,6 +609,7 @@ def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
) )
@_register_kernel_internal(equalize, datapoints.Image)
def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
if image.numel() == 0: if image.numel() == 0:
return image return image
...@@ -622,20 +682,25 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: ...@@ -622,20 +682,25 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
equalize_image_pil = _FP.equalize equalize_image_pil = _FP.equalize
@_register_kernel_internal(equalize, datapoints.Video)
def equalize_video(video: torch.Tensor) -> torch.Tensor: def equalize_video(video: torch.Tensor) -> torch.Tensor:
return equalize_image_tensor(video) return equalize_image_tensor(video)
def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: @_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask)
def invert(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(equalize) _log_api_usage_once(invert)
if torch.jit.is_scripting() or is_simple_tensor(inpt): if torch.jit.is_scripting() or is_simple_tensor(inpt):
return equalize_image_tensor(inpt) return invert_image_tensor(inpt)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints.Datapoint):
return inpt.equalize() kernel = _get_kernel(invert, type(inpt))
return kernel(
inpt,
)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return equalize_image_pil(inpt) return invert_image_pil(inpt)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...@@ -643,6 +708,7 @@ def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: ...@@ -643,6 +708,7 @@ def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
) )
@_register_kernel_internal(invert, datapoints.Image)
def invert_image_tensor(image: torch.Tensor) -> torch.Tensor: def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
if image.is_floating_point(): if image.is_floating_point():
return 1.0 - image return 1.0 - image
...@@ -656,22 +722,6 @@ def invert_image_tensor(image: torch.Tensor) -> torch.Tensor: ...@@ -656,22 +722,6 @@ def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
invert_image_pil = _FP.invert invert_image_pil = _FP.invert
@_register_kernel_internal(invert, datapoints.Video)
def invert_video(video: torch.Tensor) -> torch.Tensor: def invert_video(video: torch.Tensor) -> torch.Tensor:
return invert_image_tensor(video) return invert_image_tensor(video)
def invert(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(invert)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return invert_image_tensor(inpt)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.invert()
elif isinstance(inpt, PIL.Image.Image):
return invert_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
import math import math
import numbers import numbers
import warnings import warnings
from typing import List, Optional, Sequence, Tuple, Union from typing import Any, List, Optional, Sequence, Tuple, Union
import PIL.Image import PIL.Image
import torch import torch
...@@ -25,7 +25,13 @@ from torchvision.utils import _log_api_usage_once ...@@ -25,7 +25,13 @@ from torchvision.utils import _log_api_usage_once
from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil
from ._utils import is_simple_tensor from ._utils import (
_get_kernel,
_register_explicit_noop,
_register_five_ten_crop_kernel,
_register_kernel_internal,
is_simple_tensor,
)
def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
...@@ -39,6 +45,27 @@ def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> Interp ...@@ -39,6 +45,27 @@ def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> Interp
return interpolation return interpolation
def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(horizontal_flip)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return horizontal_flip_image_tensor(inpt)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(horizontal_flip, type(inpt))
return kernel(
inpt,
)
elif isinstance(inpt, PIL.Image.Image):
return horizontal_flip_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
@_register_kernel_internal(horizontal_flip, datapoints.Image)
def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor: def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
return image.flip(-1) return image.flip(-1)
...@@ -47,6 +74,7 @@ def horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image: ...@@ -47,6 +74,7 @@ def horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
return _FP.hflip(image) return _FP.hflip(image)
@_register_kernel_internal(horizontal_flip, datapoints.Mask)
def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor: def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
return horizontal_flip_image_tensor(mask) return horizontal_flip_image_tensor(mask)
...@@ -68,20 +96,32 @@ def horizontal_flip_bounding_boxes( ...@@ -68,20 +96,32 @@ def horizontal_flip_bounding_boxes(
return bounding_boxes.reshape(shape) return bounding_boxes.reshape(shape)
@_register_kernel_internal(horizontal_flip, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _horizontal_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> datapoints.BoundingBoxes:
output = horizontal_flip_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
)
return datapoints.BoundingBoxes.wrap_like(inpt, output)
@_register_kernel_internal(horizontal_flip, datapoints.Video)
def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor: def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
return horizontal_flip_image_tensor(video) return horizontal_flip_image_tensor(video)
def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: def vertical_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(horizontal_flip) _log_api_usage_once(vertical_flip)
if torch.jit.is_scripting() or is_simple_tensor(inpt): if torch.jit.is_scripting() or is_simple_tensor(inpt):
return horizontal_flip_image_tensor(inpt) return vertical_flip_image_tensor(inpt)
elif isinstance(inpt, datapoints._datapoint.Datapoint): elif isinstance(inpt, datapoints.Datapoint):
return inpt.horizontal_flip() kernel = _get_kernel(vertical_flip, type(inpt))
return kernel(
inpt,
)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
return horizontal_flip_image_pil(inpt) return vertical_flip_image_pil(inpt)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...@@ -89,6 +129,7 @@ def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: ...@@ -89,6 +129,7 @@ def horizontal_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
) )
@_register_kernel_internal(vertical_flip, datapoints.Image)
def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor: def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
return image.flip(-2) return image.flip(-2)
...@@ -97,6 +138,7 @@ def vertical_flip_image_pil(image: PIL.Image) -> PIL.Image: ...@@ -97,6 +138,7 @@ def vertical_flip_image_pil(image: PIL.Image) -> PIL.Image:
return _FP.vflip(image) return _FP.vflip(image)
@_register_kernel_internal(vertical_flip, datapoints.Mask)
def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor: def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
return vertical_flip_image_tensor(mask) return vertical_flip_image_tensor(mask)
...@@ -118,25 +160,17 @@ def vertical_flip_bounding_boxes( ...@@ -118,25 +160,17 @@ def vertical_flip_bounding_boxes(
return bounding_boxes.reshape(shape) return bounding_boxes.reshape(shape)
def vertical_flip_video(video: torch.Tensor) -> torch.Tensor: @_register_kernel_internal(vertical_flip, datapoints.BoundingBoxes, datapoint_wrapper=False)
return vertical_flip_image_tensor(video) def _vertical_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> datapoints.BoundingBoxes:
output = vertical_flip_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
)
return datapoints.BoundingBoxes.wrap_like(inpt, output)
def vertical_flip(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(vertical_flip)
if torch.jit.is_scripting() or is_simple_tensor(inpt): @_register_kernel_internal(vertical_flip, datapoints.Video)
return vertical_flip_image_tensor(inpt) def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
elif isinstance(inpt, datapoints._datapoint.Datapoint): return vertical_flip_image_tensor(video)
return inpt.vertical_flip()
elif isinstance(inpt, PIL.Image.Image):
return vertical_flip_image_pil(inpt)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
# We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are # We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. Still, `hflip` and `vflip` are
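Note the two `datapoint_wrapper=False` registrations above: unlike images, masks, and videos, bounding boxes carry metadata (`format`, `canvas_size`) that the plain-tensor kernel needs as arguments, so small `_*_dispatch` adapters unwrap the datapoint, call the kernel, and re-wrap the result with `wrap_like` instead of relying on the decorator's automatic wrapping. A short usage sketch, assuming the `BoundingBoxes` constructor takes `format` and `canvas_size` as used throughout this diff:

# --- editor's sketch (not part of the diff) ---
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

boxes = datapoints.BoundingBoxes(
    torch.tensor([[0.0, 0.0, 4.0, 4.0]]), format="XYXY", canvas_size=(10, 10)
)
flipped = F.horizontal_flip(boxes)  # x-coordinates are mirrored; metadata survives wrap_like
assert isinstance(flipped, datapoints.BoundingBoxes) and flipped.canvas_size == (10, 10)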
...@@ -158,6 +192,32 @@ def _compute_resized_output_size( ...@@ -158,6 +192,32 @@ def _compute_resized_output_size(
return __compute_resized_output_size(canvas_size, size=size, max_size=max_size) return __compute_resized_output_size(canvas_size, size=size, max_size=max_size)
def resize(
inpt: datapoints._InputTypeJIT,
size: List[int],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(resize)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(resize, type(inpt))
return kernel(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias)
elif isinstance(inpt, PIL.Image.Image):
if antialias is False:
warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)
@_register_kernel_internal(resize, datapoints.Image)
def resize_image_tensor( def resize_image_tensor(
image: torch.Tensor, image: torch.Tensor,
size: List[int], size: List[int],
...@@ -274,6 +334,14 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N ...@@ -274,6 +334,14 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N
return output return output
@_register_kernel_internal(resize, datapoints.Mask, datapoint_wrapper=False)
def _resize_mask_dispatch(
inpt: datapoints.Mask, size: List[int], max_size: Optional[int] = None, **kwargs: Any
) -> datapoints.Mask:
output = resize_mask(inpt.as_subclass(torch.Tensor), size, max_size=max_size)
return datapoints.Mask.wrap_like(inpt, output)
def resize_bounding_boxes( def resize_bounding_boxes(
bounding_boxes: torch.Tensor, canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None bounding_boxes: torch.Tensor, canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> Tuple[torch.Tensor, Tuple[int, int]]: ) -> Tuple[torch.Tensor, Tuple[int, int]]:
...@@ -292,6 +360,17 @@ def resize_bounding_boxes( ...@@ -292,6 +360,17 @@ def resize_bounding_boxes(
) )
@_register_kernel_internal(resize, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _resize_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, size: List[int], max_size: Optional[int] = None, **kwargs: Any
) -> datapoints.BoundingBoxes:
output, canvas_size = resize_bounding_boxes(
inpt.as_subclass(torch.Tensor), inpt.canvas_size, size, max_size=max_size
)
return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
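`resize` shows why these adapters also take `**kwargs`: the dispatcher forwards `interpolation`, `max_size`, and `antialias` to every kernel, but masks and bounding boxes only use `size` and `max_size`, so the extra arguments are absorbed and ignored; the bounding-box adapter additionally re-wraps with the updated `canvas_size`. A small sketch, assuming the constructor used below:

# --- editor's sketch (not part of the diff) ---
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

boxes = datapoints.BoundingBoxes(
    torch.tensor([[0.0, 0.0, 4.0, 4.0]]), format="XYXY", canvas_size=(10, 10)
)
resized = F.resize(boxes, [20, 20])  # interpolation/antialias are silently ignored for boxes
assert resized.canvas_size == (20, 20)  # coordinates are rescaled and the canvas size is updated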
@_register_kernel_internal(resize, datapoints.Video)
def resize_video( def resize_video(
video: torch.Tensor, video: torch.Tensor,
size: List[int], size: List[int],
...@@ -302,23 +381,54 @@ def resize_video( ...@@ -302,23 +381,54 @@ def resize_video(
return resize_image_tensor(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) return resize_image_tensor(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
def resize( def affine(
inpt: datapoints._InputTypeJIT, inpt: datapoints._InputTypeJIT,
size: List[int], angle: Union[int, float],
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, translate: List[float],
max_size: Optional[int] = None, scale: float,
antialias: Optional[Union[str, bool]] = "warn", shear: List[float],
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: datapoints._FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> datapoints._InputTypeJIT: ) -> datapoints._InputTypeJIT:
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_log_api_usage_once(resize) _log_api_usage_once(affine)
# TODO: consider deprecating integers from angle and shear in the future
if torch.jit.is_scripting() or is_simple_tensor(inpt): if torch.jit.is_scripting() or is_simple_tensor(inpt):
return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) return affine_image_tensor(
elif isinstance(inpt, datapoints._datapoint.Datapoint): inpt,
return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias) angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
elif isinstance(inpt, datapoints.Datapoint):
kernel = _get_kernel(affine, type(inpt))
return kernel(
inpt,
angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
elif isinstance(inpt, PIL.Image.Image): elif isinstance(inpt, PIL.Image.Image):
if antialias is False: return affine_image_pil(
warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.") inpt,
return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size) angle,
translate=translate,
scale=scale,
shear=shear,
interpolation=interpolation,
fill=fill,
center=center,
)
else: else:
raise TypeError( raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...@@ -574,6 +684,7 @@ def _affine_grid( ...@@ -574,6 +684,7 @@ def _affine_grid(
return output_grid.view(1, oh, ow, 2) return output_grid.view(1, oh, ow, 2)
@_register_kernel_internal(affine, datapoints.Image)
def affine_image_tensor(
    image: torch.Tensor,
    angle: Union[int, float],
...@@ -763,6 +874,29 @@ def affine_bounding_boxes(
    return out_box
@_register_kernel_internal(affine, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _affine_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
center: Optional[List[float]] = None,
**kwargs,
) -> datapoints.BoundingBoxes:
output = affine_bounding_boxes(
inpt.as_subclass(torch.Tensor),
format=inpt.format,
canvas_size=inpt.canvas_size,
angle=angle,
translate=translate,
scale=scale,
shear=shear,
center=center,
)
return datapoints.BoundingBoxes.wrap_like(inpt, output)
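For illustration, a call like the one below would take the datapoint branch of `affine`, resolve `_affine_bounding_boxes_dispatch` through the registry, unwrap the boxes to a plain tensor for `affine_bounding_boxes`, and re-wrap the result with `wrap_like`. The constructor keywords follow the names used in this diff (`format`, `canvas_size`) but are otherwise an assumption about the surrounding API, so treat this as a sketch rather than a verified snippet.

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

boxes = datapoints.BoundingBoxes(
    torch.tensor([[10.0, 10.0, 20.0, 20.0]]),
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(32, 32),
)
# interpolation/fill are accepted by the public dispatcher but ignored by the
# bounding-box wrapper via its **kwargs.
out = F.affine(boxes, angle=15.0, translate=[0.0, 0.0], scale=1.0, shear=[0.0, 0.0])
assert isinstance(out, datapoints.BoundingBoxes)
assert out.canvas_size == boxes.canvas_size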
def affine_mask(
    mask: torch.Tensor,
    angle: Union[int, float],
...@@ -795,6 +929,30 @@ def affine_mask(
    return output
@_register_kernel_internal(affine, datapoints.Mask, datapoint_wrapper=False)
def _affine_mask_dispatch(
inpt: datapoints.Mask,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
fill: datapoints._FillTypeJIT = None,
center: Optional[List[float]] = None,
**kwargs,
) -> datapoints.Mask:
output = affine_mask(
inpt.as_subclass(torch.Tensor),
angle=angle,
translate=translate,
scale=scale,
shear=shear,
fill=fill,
center=center,
)
return datapoints.Mask.wrap_like(inpt, output)
@_register_kernel_internal(affine, datapoints.Video)
def affine_video(
    video: torch.Tensor,
    angle: Union[int, float],
...@@ -817,46 +975,24 @@ def affine_video(
    )
-def affine(
+def rotate(
     inpt: datapoints._InputTypeJIT,
-    angle: Union[int, float],
-    translate: List[float],
-    scale: float,
-    shear: List[float],
+    angle: float,
     interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-    fill: datapoints._FillTypeJIT = None,
+    expand: bool = False,
     center: Optional[List[float]] = None,
+    fill: datapoints._FillTypeJIT = None,
 ) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
-        _log_api_usage_once(affine)
-    # TODO: consider deprecating integers from angle and shear on the future
+        _log_api_usage_once(rotate)
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return affine_image_tensor(
-            inpt,
-            angle,
-            translate=translate,
-            scale=scale,
-            shear=shear,
-            interpolation=interpolation,
-            fill=fill,
-            center=center,
-        )
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.affine(
-            angle, translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center
-        )
+        return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(rotate, type(inpt))
+        return kernel(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
     elif isinstance(inpt, PIL.Image.Image):
-        return affine_image_pil(
-            inpt,
-            angle,
-            translate=translate,
-            scale=scale,
-            shear=shear,
-            interpolation=interpolation,
-            fill=fill,
-            center=center,
-        )
+        return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
     else:
         raise TypeError(
             f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...@@ -864,6 +1000,7 @@ def affine(
        )
@_register_kernel_internal(rotate, datapoints.Image)
def rotate_image_tensor(
    image: torch.Tensor,
    angle: float,
...@@ -951,6 +1088,21 @@ def rotate_bounding_boxes(
    )
@_register_kernel_internal(rotate, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _rotate_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, angle: float, expand: bool = False, center: Optional[List[float]] = None, **kwargs
) -> datapoints.BoundingBoxes:
output, canvas_size = rotate_bounding_boxes(
inpt.as_subclass(torch.Tensor),
format=inpt.format,
canvas_size=inpt.canvas_size,
angle=angle,
expand=expand,
center=center,
)
return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
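Rotation with `expand=True` enlarges the canvas, which is why this wrapper (like the resize and crop ones) re-wraps with the `canvas_size` returned by the kernel instead of reusing the input's. A hedged usage sketch, with the same constructor assumptions as the example above:

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

boxes = datapoints.BoundingBoxes(
    torch.tensor([[0.0, 0.0, 10.0, 10.0]]),
    format=datapoints.BoundingBoxFormat.XYXY,
    canvas_size=(20, 40),
)
rotated = F.rotate(boxes, angle=90.0, expand=True)
# expand=True recomputes the canvas, so the output's canvas_size generally
# differs from boxes.canvas_size (roughly swapped for a 90 degree rotation).
print(boxes.canvas_size, rotated.canvas_size)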
def rotate_mask(
    mask: torch.Tensor,
    angle: float,
...@@ -979,6 +1131,20 @@ def rotate_mask(
    return output
@_register_kernel_internal(rotate, datapoints.Mask, datapoint_wrapper=False)
def _rotate_mask_dispatch(
inpt: datapoints.Mask,
angle: float,
expand: bool = False,
center: Optional[List[float]] = None,
fill: datapoints._FillTypeJIT = None,
**kwargs,
) -> datapoints.Mask:
output = rotate_mask(inpt.as_subclass(torch.Tensor), angle=angle, expand=expand, fill=fill, center=center)
return datapoints.Mask.wrap_like(inpt, output)
@_register_kernel_internal(rotate, datapoints.Video)
def rotate_video(
    video: torch.Tensor,
    angle: float,
...@@ -990,23 +1156,23 @@ def rotate_video(
    return rotate_image_tensor(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
-def rotate(
+def pad(
     inpt: datapoints._InputTypeJIT,
-    angle: float,
-    interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
-    expand: bool = False,
-    center: Optional[List[float]] = None,
-    fill: datapoints._FillTypeJIT = None,
+    padding: List[int],
+    fill: Optional[Union[int, float, List[float]]] = None,
+    padding_mode: str = "constant",
 ) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
-        _log_api_usage_once(rotate)
+        _log_api_usage_once(pad)
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.rotate(angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
+        return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
+
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(pad, type(inpt))
+        return kernel(inpt, padding, fill=fill, padding_mode=padding_mode)
     elif isinstance(inpt, PIL.Image.Image):
-        return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
+        return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
     else:
         raise TypeError(
             f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...@@ -1038,6 +1204,7 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
    return [pad_left, pad_right, pad_top, pad_bottom]
@_register_kernel_internal(pad, datapoints.Image)
def pad_image_tensor(
    image: torch.Tensor,
    padding: List[int],
...@@ -1139,6 +1306,7 @@ def _pad_with_vector_fill(
pad_image_pil = _FP.pad
@_register_kernel_internal(pad, datapoints.Mask)
def pad_mask(
    mask: torch.Tensor,
    padding: List[int],
...@@ -1192,6 +1360,21 @@ def pad_bounding_boxes(
    return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size
@_register_kernel_internal(pad, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _pad_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, padding: List[int], padding_mode: str = "constant", **kwargs
) -> datapoints.BoundingBoxes:
output, canvas_size = pad_bounding_boxes(
inpt.as_subclass(torch.Tensor),
format=inpt.format,
canvas_size=inpt.canvas_size,
padding=padding,
padding_mode=padding_mode,
)
return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
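Note the split in how kernels are registered: plain kernels such as pad_image_tensor, pad_mask, and pad_video are registered directly and presumably rely on the registration machinery to unwrap and re-wrap the datapoint, while the bounding-box wrapper above opts out with datapoint_wrapper=False because it must also propagate canvas_size. A hedged sketch of what the default wrapping path could look like; the helper name _wrap_kernel is hypothetical.

import functools
import torch

def _wrap_kernel(kernel):
    # Hypothetical default path for datapoint_wrapper=True: run the tensor
    # kernel on the unwrapped data, then restore the original datapoint type.
    @functools.wraps(kernel)
    def wrapper(inpt, *args, **kwargs):
        output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
        return type(inpt).wrap_like(inpt, output)
    return wrapper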
@_register_kernel_internal(pad, datapoints.Video)
def pad_video(
    video: torch.Tensor,
    padding: List[int],
...@@ -1201,22 +1384,17 @@ def pad_video(
    return pad_image_tensor(video, padding, fill=fill, padding_mode=padding_mode)
-def pad(
-    inpt: datapoints._InputTypeJIT,
-    padding: List[int],
-    fill: Optional[Union[int, float, List[float]]] = None,
-    padding_mode: str = "constant",
-) -> datapoints._InputTypeJIT:
+def crop(inpt: datapoints._InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
-        _log_api_usage_once(pad)
+        _log_api_usage_once(crop)
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
-
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.pad(padding, fill=fill, padding_mode=padding_mode)
+        return crop_image_tensor(inpt, top, left, height, width)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(crop, type(inpt))
+        return kernel(inpt, top, left, height, width)
     elif isinstance(inpt, PIL.Image.Image):
-        return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
+        return crop_image_pil(inpt, top, left, height, width)
     else:
         raise TypeError(
             f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...@@ -1224,6 +1402,7 @@ def pad(
        )
@_register_kernel_internal(crop, datapoints.Image)
def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    h, w = image.shape[-2:]
...@@ -1266,6 +1445,17 @@ def crop_bounding_boxes(
    return clamp_bounding_boxes(bounding_boxes, format=format, canvas_size=canvas_size), canvas_size
@_register_kernel_internal(crop, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _crop_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, top: int, left: int, height: int, width: int
) -> datapoints.BoundingBoxes:
output, canvas_size = crop_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width
)
return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
@_register_kernel_internal(crop, datapoints.Mask)
def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
...@@ -1281,20 +1471,32 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int)
    return output
@_register_kernel_internal(crop, datapoints.Video)
def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    return crop_image_tensor(video, top, left, height, width)
-def crop(inpt: datapoints._InputTypeJIT, top: int, left: int, height: int, width: int) -> datapoints._InputTypeJIT:
+def perspective(
+    inpt: datapoints._InputTypeJIT,
+    startpoints: Optional[List[List[int]]],
+    endpoints: Optional[List[List[int]]],
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
+    fill: datapoints._FillTypeJIT = None,
+    coefficients: Optional[List[float]] = None,
+) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
-        _log_api_usage_once(crop)
+        _log_api_usage_once(perspective)
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return crop_image_tensor(inpt, top, left, height, width)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.crop(top, left, height, width)
+        return perspective_image_tensor(
+            inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
+        )
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(perspective, type(inpt))
+        return kernel(inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients)
     elif isinstance(inpt, PIL.Image.Image):
-        return crop_image_pil(inpt, top, left, height, width)
+        return perspective_image_pil(
+            inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
+        )
     else:
         raise TypeError(
             f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...@@ -1349,6 +1551,7 @@ def _perspective_coefficients(
        raise ValueError("Either the startpoints/endpoints or the coefficients must have non `None` values.")
@_register_kernel_internal(perspective, datapoints.Image)
def perspective_image_tensor(
    image: torch.Tensor,
    startpoints: Optional[List[List[int]]],
...@@ -1503,6 +1706,25 @@ def perspective_bounding_boxes(
    ).reshape(original_shape)
@_register_kernel_internal(perspective, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _perspective_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
coefficients: Optional[List[float]] = None,
**kwargs,
) -> datapoints.BoundingBoxes:
output = perspective_bounding_boxes(
inpt.as_subclass(torch.Tensor),
format=inpt.format,
canvas_size=inpt.canvas_size,
startpoints=startpoints,
endpoints=endpoints,
coefficients=coefficients,
)
return datapoints.BoundingBoxes.wrap_like(inpt, output)
def perspective_mask(
    mask: torch.Tensor,
    startpoints: Optional[List[List[int]]],
...@@ -1526,6 +1748,26 @@ def perspective_mask(
    return output
@_register_kernel_internal(perspective, datapoints.Mask, datapoint_wrapper=False)
def _perspective_mask_dispatch(
inpt: datapoints.Mask,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
fill: datapoints._FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
**kwargs,
) -> datapoints.Mask:
output = perspective_mask(
inpt.as_subclass(torch.Tensor),
startpoints=startpoints,
endpoints=endpoints,
fill=fill,
coefficients=coefficients,
)
return datapoints.Mask.wrap_like(inpt, output)
@_register_kernel_internal(perspective, datapoints.Video)
def perspective_video(
    video: torch.Tensor,
    startpoints: Optional[List[List[int]]],
...@@ -1539,28 +1781,25 @@ def perspective_video(
    )
-def perspective(
+def elastic(
     inpt: datapoints._InputTypeJIT,
-    startpoints: Optional[List[List[int]]],
-    endpoints: Optional[List[List[int]]],
+    displacement: torch.Tensor,
     interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
     fill: datapoints._FillTypeJIT = None,
-    coefficients: Optional[List[float]] = None,
 ) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
-        _log_api_usage_once(perspective)
+        _log_api_usage_once(elastic)
+    if not isinstance(displacement, torch.Tensor):
+        raise TypeError("Argument displacement should be a Tensor")
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return perspective_image_tensor(
-            inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
-        )
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.perspective(
-            startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
-        )
+        return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(elastic, type(inpt))
+        return kernel(inpt, displacement, interpolation=interpolation, fill=fill)
     elif isinstance(inpt, PIL.Image.Image):
-        return perspective_image_pil(
-            inpt, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients
-        )
+        return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill)
     else:
         raise TypeError(
             f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...@@ -1568,6 +1807,10 @@ def perspective(
        )
elastic_transform = elastic
@_register_kernel_internal(elastic, datapoints.Image)
def elastic_image_tensor(
    image: torch.Tensor,
    displacement: torch.Tensor,
...@@ -1699,6 +1942,16 @@ def elastic_bounding_boxes(
    ).reshape(original_shape)
@_register_kernel_internal(elastic, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _elastic_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, displacement: torch.Tensor, **kwargs
) -> datapoints.BoundingBoxes:
output = elastic_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, displacement=displacement
)
return datapoints.BoundingBoxes.wrap_like(inpt, output)
def elastic_mask(
    mask: torch.Tensor,
    displacement: torch.Tensor,
...@@ -1718,6 +1971,15 @@ def elastic_mask(
    return output
@_register_kernel_internal(elastic, datapoints.Mask, datapoint_wrapper=False)
def _elastic_mask_dispatch(
inpt: datapoints.Mask, displacement: torch.Tensor, fill: datapoints._FillTypeJIT = None, **kwargs
) -> datapoints.Mask:
output = elastic_mask(inpt.as_subclass(torch.Tensor), displacement=displacement, fill=fill)
return datapoints.Mask.wrap_like(inpt, output)
@_register_kernel_internal(elastic, datapoints.Video)
def elastic_video(
    video: torch.Tensor,
    displacement: torch.Tensor,
...@@ -1727,24 +1989,17 @@ def elastic_video(
    return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)
-def elastic(
-    inpt: datapoints._InputTypeJIT,
-    displacement: torch.Tensor,
-    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    fill: datapoints._FillTypeJIT = None,
-) -> datapoints._InputTypeJIT:
+def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
-        _log_api_usage_once(elastic)
-    if not isinstance(displacement, torch.Tensor):
-        raise TypeError("Argument displacement should be a Tensor")
+        _log_api_usage_once(center_crop)
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.elastic(displacement, interpolation=interpolation, fill=fill)
+        return center_crop_image_tensor(inpt, output_size)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(center_crop, type(inpt))
+        return kernel(inpt, output_size)
     elif isinstance(inpt, PIL.Image.Image):
-        return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill)
+        return center_crop_image_pil(inpt, output_size)
     else:
         raise TypeError(
             f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...@@ -1752,9 +2007,6 @@ def elastic(
        )
-elastic_transform = elastic
def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
    if isinstance(output_size, numbers.Number):
        s = int(output_size)
...@@ -1782,6 +2034,7 @@ def _center_crop_compute_crop_anchor(
    return crop_top, crop_left
@_register_kernel_internal(center_crop, datapoints.Image)
def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    shape = image.shape
...@@ -1831,6 +2084,17 @@ def center_crop_bounding_boxes(
    )
@_register_kernel_internal(center_crop, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _center_crop_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, output_size: List[int]
) -> datapoints.BoundingBoxes:
output, canvas_size = center_crop_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, output_size=output_size
)
return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
@_register_kernel_internal(center_crop, datapoints.Mask)
def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    if mask.ndim < 3:
        mask = mask.unsqueeze(0)
...@@ -1846,20 +2110,33 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor
    return output
@_register_kernel_internal(center_crop, datapoints.Video)
def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    return center_crop_image_tensor(video, output_size)
-def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datapoints._InputTypeJIT:
+def resized_crop(
+    inpt: datapoints._InputTypeJIT,
+    top: int,
+    left: int,
+    height: int,
+    width: int,
+    size: List[int],
+    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
+    antialias: Optional[Union[str, bool]] = "warn",
+) -> datapoints._InputTypeJIT:
     if not torch.jit.is_scripting():
-        _log_api_usage_once(center_crop)
+        _log_api_usage_once(resized_crop)
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return center_crop_image_tensor(inpt, output_size)
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.center_crop(output_size)
+        return resized_crop_image_tensor(
+            inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
+        )
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(resized_crop, type(inpt))
+        return kernel(inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation)
     elif isinstance(inpt, PIL.Image.Image):
-        return center_crop_image_pil(inpt, output_size)
+        return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation)
     else:
         raise TypeError(
             f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...@@ -1867,6 +2144,7 @@ def center_crop(inpt: datapoints._InputTypeJIT, output_size: List[int]) -> datap
        )
@_register_kernel_internal(resized_crop, datapoints.Image)
def resized_crop_image_tensor(
    image: torch.Tensor,
    top: int,
...@@ -1904,8 +2182,18 @@ def resized_crop_bounding_boxes(
    width: int,
    size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
-    bounding_boxes, _ = crop_bounding_boxes(bounding_boxes, format, top, left, height, width)
-    return resize_bounding_boxes(bounding_boxes, canvas_size=(height, width), size=size)
+    bounding_boxes, canvas_size = crop_bounding_boxes(bounding_boxes, format, top, left, height, width)
+    return resize_bounding_boxes(bounding_boxes, canvas_size=canvas_size, size=size)
@_register_kernel_internal(resized_crop, datapoints.BoundingBoxes, datapoint_wrapper=False)
def _resized_crop_bounding_boxes_dispatch(
inpt: datapoints.BoundingBoxes, top: int, left: int, height: int, width: int, size: List[int], **kwargs
) -> datapoints.BoundingBoxes:
output, canvas_size = resized_crop_bounding_boxes(
inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width, size=size
)
return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size)
def resized_crop_mask(
...@@ -1920,6 +2208,17 @@ def resized_crop_mask(
    return resize_mask(mask, size)
@_register_kernel_internal(resized_crop, datapoints.Mask, datapoint_wrapper=False)
def _resized_crop_mask_dispatch(
inpt: datapoints.Mask, top: int, left: int, height: int, width: int, size: List[int], **kwargs
) -> datapoints.Mask:
output = resized_crop_mask(
inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size
)
return datapoints.Mask.wrap_like(inpt, output)
@_register_kernel_internal(resized_crop, datapoints.Video)
def resized_crop_video(
    video: torch.Tensor,
    top: int,
...@@ -1935,27 +2234,26 @@ def resized_crop_video(
    )
-def resized_crop(
-    inpt: datapoints._InputTypeJIT,
-    top: int,
-    left: int,
-    height: int,
-    width: int,
-    size: List[int],
-    interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
-    antialias: Optional[Union[str, bool]] = "warn",
-) -> datapoints._InputTypeJIT:
+@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True)
+def five_crop(
+    inpt: datapoints._InputTypeJIT, size: List[int]
+) -> Tuple[
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+]:
     if not torch.jit.is_scripting():
-        _log_api_usage_once(resized_crop)
+        _log_api_usage_once(five_crop)
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return resized_crop_image_tensor(
-            inpt, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation
-        )
-    elif isinstance(inpt, datapoints._datapoint.Datapoint):
-        return inpt.resized_crop(top, left, height, width, antialias=antialias, size=size, interpolation=interpolation)
+        return five_crop_image_tensor(inpt, size)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(five_crop, type(inpt))
+        return kernel(inpt, size)
     elif isinstance(inpt, PIL.Image.Image):
-        return resized_crop_image_pil(inpt, top, left, height, width, size=size, interpolation=interpolation)
+        return five_crop_image_pil(inpt, size)
     else:
         raise TypeError(
             f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
...@@ -1977,6 +2275,7 @@ def _parse_five_crop_size(size: List[int]) -> List[int]:
    return size
@_register_five_ten_crop_kernel(five_crop, datapoints.Image)
def five_crop_image_tensor(
    image: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
...@@ -2014,38 +2313,46 @@ def five_crop_image_pil(
    return tl, tr, bl, br, center
@_register_five_ten_crop_kernel(five_crop, datapoints.Video)
def five_crop_video(
    video: torch.Tensor, size: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    return five_crop_image_tensor(video, size)
-ImageOrVideoTypeJIT = Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]
-def five_crop(
-    inpt: ImageOrVideoTypeJIT, size: List[int]
-) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
+@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True)
+def ten_crop(
+    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], size: List[int], vertical_flip: bool = False
+) -> Tuple[
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+    datapoints._InputTypeJIT,
+]:
     if not torch.jit.is_scripting():
-        _log_api_usage_once(five_crop)
+        _log_api_usage_once(ten_crop)
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return five_crop_image_tensor(inpt, size)
-    elif isinstance(inpt, datapoints.Image):
-        output = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size)
-        return tuple(datapoints.Image.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
-    elif isinstance(inpt, datapoints.Video):
-        output = five_crop_video(inpt.as_subclass(torch.Tensor), size)
-        return tuple(datapoints.Video.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
+        return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
+    elif isinstance(inpt, datapoints.Datapoint):
+        kernel = _get_kernel(ten_crop, type(inpt))
+        return kernel(inpt, size, vertical_flip=vertical_flip)
     elif isinstance(inpt, PIL.Image.Image):
-        return five_crop_image_pil(inpt, size)
+        return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
     else:
         raise TypeError(
-            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
+            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
            f"but got {type(inpt)} instead."
        )
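five_crop and ten_crop return tuples rather than a single transformed input, so they get their own _register_five_ten_crop_kernel registrations, and _register_explicit_noop marks BoundingBoxes and Mask as pass-through (apparently with a warning, given warn_passthrough=True). A hedged usage sketch for the image path, assuming the tuple elements come back with the expected crop size:

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

img = datapoints.Image(torch.rand(3, 64, 64))
crops = F.ten_crop(img, size=[32, 32])  # dispatches to ten_crop_image_tensor per the registration above
assert len(crops) == 10
assert all(c.shape[-2:] == (32, 32) for c in crops)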
@_register_five_ten_crop_kernel(ten_crop, datapoints.Image)
def ten_crop_image_tensor(
    image: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
...@@ -2099,6 +2406,7 @@ def ten_crop_image_pil(
    return non_flipped + flipped
@_register_five_ten_crop_kernel(ten_crop, datapoints.Video)
def ten_crop_video(
    video: torch.Tensor, size: List[int], vertical_flip: bool = False
) -> Tuple[
...@@ -2114,37 +2422,3 @@ def ten_crop_video(
    torch.Tensor,
]:
    return ten_crop_image_tensor(video, size, vertical_flip=vertical_flip)
-def ten_crop(
-    inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], size: List[int], vertical_flip: bool = False
-) -> Tuple[
-    ImageOrVideoTypeJIT,
-    ImageOrVideoTypeJIT,
-    ImageOrVideoTypeJIT,
-    ImageOrVideoTypeJIT,
-    ImageOrVideoTypeJIT,
-    ImageOrVideoTypeJIT,
-    ImageOrVideoTypeJIT,
-    ImageOrVideoTypeJIT,
-    ImageOrVideoTypeJIT,
-    ImageOrVideoTypeJIT,
-]:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(ten_crop)
-    if torch.jit.is_scripting() or is_simple_tensor(inpt):
-        return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
-    elif isinstance(inpt, datapoints.Image):
-        output = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
-        return tuple(datapoints.Image.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
-    elif isinstance(inpt, datapoints.Video):
-        output = ten_crop_video(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
-        return tuple(datapoints.Video.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
-    elif isinstance(inpt, PIL.Image.Image):
-        return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
-    else:
-        raise TypeError(
-            f"Input can either be a plain tensor, an `Image` or `Video` datapoint, or a PIL image, "
-            f"but got {type(inpt)} instead."
-        )