Unverified commit 7251769f authored by Philip Meier, committed by GitHub

Transforms without dispatcher (#5421)



* add prototype transforms that don't need dispatchers

* cleanup

* remove legacy_transform decorator

* remove legacy classes

* remove explicit param passing

* streamline extra_repr

* remove obsolete ._supports() method

* cleanup

* remove Query

* cleanup

* fix tests

* kernels -> functional

* move image size and num channels extraction to functional

* extend legacy function to extract image size and num channels

* implement dispatching for auto augment

* fix auto augment dispatch

* revert some naming changes

* remove ability to pass params to autoaugment

* fix legacy image size extraction

* align prototype.transforms.functional with transforms.functional

* cleanup

* fix image size and channels extraction

* fix affine and rotate

* revert image size to (width, height)

* Minor corrections
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent f15ba56f
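The diff below removes the _DISPATCHER / _FAIL_TYPES machinery from the prototype transforms and moves to explicit per-type dispatch inside each transform's _transform() method, with the low-level kernels exposed as transforms.functional instead of transforms.kernels. As a reading aid, here is a minimal, self-contained sketch of the resulting pattern in plain PyTorch. It is not the torchvision code itself: apply_recursively, the feature classes, and the real image kernels are simplified away, and the tensor branch below just flips the last dimension.

import torch
from torch import nn


class Transform(nn.Module):
    # Base class: forward() computes params once per call, then applies _transform to the sample.
    def _get_params(self, sample):
        return dict()

    def _transform(self, input, params):
        raise NotImplementedError

    def forward(self, *inputs):
        sample = inputs if len(inputs) > 1 else inputs[0]
        # the real prototype code walks nested containers with apply_recursively;
        # this sketch only handles a single input
        return self._transform(sample, self._get_params(sample))


class HorizontalFlip(Transform):
    # Explicit per-type dispatch replaces the old _DISPATCHER class attribute.
    def _transform(self, input, params):
        if isinstance(input, torch.Tensor):
            return input.flip(-1)  # stand-in for F.horizontal_flip_image_tensor
        else:
            return input  # unsupported types pass through unchanged


print(HorizontalFlip()(torch.arange(6).reshape(2, 3)))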
 import itertools
 
 import PIL.Image
 import pytest
 import torch
-from test_prototype_transforms_kernels import make_images, make_bounding_boxes, make_one_hot_labels
+from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels
 from torchvision.prototype import transforms, features
 from torchvision.transforms.functional import to_pil_image
@@ -25,15 +24,6 @@ def make_vanilla_tensor_bounding_boxes(*args, **kwargs):
         yield bounding_box.data
 
-INPUT_CREATIONS_FNS = {
-    features.Image: make_images,
-    features.BoundingBox: make_bounding_boxes,
-    features.OneHotLabel: make_one_hot_labels,
-    torch.Tensor: make_vanilla_tensor_images,
-    PIL.Image.Image: make_pil_images,
-}
-
 def parametrize(transforms_with_inputs):
     return pytest.mark.parametrize(
         ("transform", "input"),
@@ -52,14 +42,20 @@ def parametrize(transforms_with_inputs):
 def parametrize_from_transforms(*transforms):
     transforms_with_inputs = []
     for transform in transforms:
-        dispatcher = transform._DISPATCHER
-        if dispatcher is None:
-            continue
-
-        for type_ in dispatcher._kernels:
+        for creation_fn in [
+            make_images,
+            make_bounding_boxes,
+            make_one_hot_labels,
+            make_vanilla_tensor_images,
+            make_pil_images,
+        ]:
+            inputs = list(creation_fn())
             try:
-                inputs = INPUT_CREATIONS_FNS[type_]()
-            except KeyError:
+                output = transform(inputs[0])
+            except Exception:
                 continue
+            else:
+                if output is inputs[0]:
+                    continue
 
             transforms_with_inputs.append((transform, inputs))
@@ -69,7 +65,7 @@ def parametrize_from_transforms(*transforms):
 class TestSmoke:
     @parametrize_from_transforms(
-        transforms.RandomErasing(),
+        transforms.RandomErasing(p=1.0),
         transforms.HorizontalFlip(),
         transforms.Resize([16, 16]),
         transforms.CenterCrop([16, 16]),
@@ -141,35 +137,6 @@ class TestSmoke:
     def test_normalize(self, transform, input):
         transform(input)
 
-    @parametrize(
-        [
-            (
-                transforms.ConvertColorSpace("grayscale"),
-                itertools.chain(
-                    make_images(),
-                    make_vanilla_tensor_images(color_spaces=["rgb"]),
-                    make_pil_images(color_spaces=["rgb"]),
-                ),
-            )
-        ]
-    )
-    def test_convert_bounding_color_space(self, transform, input):
-        transform(input)
-
-    @parametrize(
-        [
-            (
-                transforms.ConvertBoundingBoxFormat("xyxy", old_format="xywh"),
-                itertools.chain(
-                    make_bounding_boxes(),
-                    make_vanilla_tensor_bounding_boxes(formats=["xywh"]),
-                ),
-            )
-        ]
-    )
-    def test_convert_bounding_box_format(self, transform, input):
-        transform(input)
-
     @parametrize(
         [
             (
...
@@ -3,7 +3,7 @@ import itertools
 import pytest
 import torch.testing
-import torchvision.prototype.transforms.kernels as K
+import torchvision.prototype.transforms.functional as F
 from torch import jit
 from torch.nn.functional import one_hot
 from torchvision.prototype import features
@@ -134,10 +134,10 @@ class SampleInput:
         self.kwargs = kwargs
 
-class KernelInfo:
+class FunctionalInfo:
     def __init__(self, name, *, sample_inputs_fn):
         self.name = name
-        self.kernel = getattr(K, name)
+        self.functional = getattr(F, name)
         self._sample_inputs_fn = sample_inputs_fn
 
     def sample_inputs(self):
@@ -146,21 +146,21 @@ class KernelInfo:
     def __call__(self, *args, **kwargs):
         if len(args) == 1 and not kwargs and isinstance(args[0], SampleInput):
             sample_input = args[0]
-            return self.kernel(*sample_input.args, **sample_input.kwargs)
+            return self.functional(*sample_input.args, **sample_input.kwargs)
 
-        return self.kernel(*args, **kwargs)
+        return self.functional(*args, **kwargs)
 
-KERNEL_INFOS = []
+FUNCTIONAL_INFOS = []
 
 def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn):
-    KERNEL_INFOS.append(KernelInfo(sample_inputs_fn.__name__, sample_inputs_fn=sample_inputs_fn))
+    FUNCTIONAL_INFOS.append(FunctionalInfo(sample_inputs_fn.__name__, sample_inputs_fn=sample_inputs_fn))
     return sample_inputs_fn
 
 @register_kernel_info_from_sample_inputs_fn
-def horizontal_flip_image():
+def horizontal_flip_image_tensor():
     for image in make_images():
         yield SampleInput(image)
@@ -172,12 +172,12 @@ def horizontal_flip_bounding_box():
 @register_kernel_info_from_sample_inputs_fn
-def resize_image():
+def resize_image_tensor():
     for image, interpolation in itertools.product(
         make_images(),
         [
-            K.InterpolationMode.BILINEAR,
-            K.InterpolationMode.NEAREST,
+            F.InterpolationMode.BILINEAR,
+            F.InterpolationMode.NEAREST,
         ],
     ):
         height, width = image.shape[-2:]
@@ -200,20 +200,20 @@ def resize_bounding_box():
 class TestKernelsCommon:
-    @pytest.mark.parametrize("kernel_info", KERNEL_INFOS, ids=lambda kernel_info: kernel_info.name)
-    def test_scriptable(self, kernel_info):
-        jit.script(kernel_info.kernel)
+    @pytest.mark.parametrize("functional_info", FUNCTIONAL_INFOS, ids=lambda functional_info: functional_info.name)
+    def test_scriptable(self, functional_info):
+        jit.script(functional_info.functional)
 
     @pytest.mark.parametrize(
-        ("kernel_info", "sample_input"),
+        ("functional_info", "sample_input"),
         [
-            pytest.param(kernel_info, sample_input, id=f"{kernel_info.name}-{idx}")
-            for kernel_info in KERNEL_INFOS
-            for idx, sample_input in enumerate(kernel_info.sample_inputs())
+            pytest.param(functional_info, sample_input, id=f"{functional_info.name}-{idx}")
+            for functional_info in FUNCTIONAL_INFOS
+            for idx, sample_input in enumerate(functional_info.sample_inputs())
         ],
     )
-    def test_eager_vs_scripted(self, kernel_info, sample_input):
-        eager = kernel_info(sample_input)
-        scripted = jit.script(kernel_info.kernel)(*sample_input.args, **sample_input.kwargs)
+    def test_eager_vs_scripted(self, functional_info, sample_input):
+        eager = functional_info(sample_input)
+        scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs)
         torch.testing.assert_close(eager, scripted)
@@ -41,7 +41,7 @@ class BoundingBox(_Feature):
         # promote this out of the prototype state
         # import at runtime to avoid cyclic imports
-        from torchvision.prototype.transforms.kernels import convert_bounding_box_format
+        from torchvision.prototype.transforms.functional import convert_bounding_box_format
 
         if isinstance(format, str):
             format = BoundingBoxFormat[format]
...
@@ -43,7 +43,7 @@ class EncodedImage(EncodedData):
         # promote this out of the prototype state
         # import at runtime to avoid cyclic imports
-        from torchvision.prototype.transforms.kernels import decode_image_with_pil
+        from torchvision.prototype.transforms.functional import decode_image_with_pil
 
         return Image(decode_image_with_pil(self))
...
-from torchvision.transforms import AutoAugmentPolicy, InterpolationMode  # usort: skip
+from torchvision.transforms import InterpolationMode, AutoAugmentPolicy  # usort: skip
 
-from . import kernels  # usort: skip
 from . import functional  # usort: skip
 from ._transform import Transform  # usort: skip
 
 from ._augment import RandomErasing, RandomMixup, RandomCutmix
 from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment
 from ._container import Compose, RandomApply, RandomChoice, RandomOrder
 from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop
-from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertColorSpace
+from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
 from ._misc import Identity, Normalize, ToDtype, Lambda
 from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
 from ._type_conversion import DecodeImage, LabelToOneHot
@@ -3,7 +3,6 @@ import numbers
 import warnings
 from typing import Any, Dict, Tuple
 
-import PIL.Image
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, functional as F
@@ -12,9 +11,6 @@ from ._utils import query_image
 class RandomErasing(Transform):
-    _DISPATCHER = F.erase
-    _FAIL_TYPES = {PIL.Image.Image, features.BoundingBox, features.SegmentationMask}
-
     def __init__(
         self,
         p: float = 0.5,
@@ -45,8 +41,8 @@ class RandomErasing(Transform):
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         image = query_image(sample)
-        img_h, img_w = F.get_image_size(image)
         img_c = F.get_image_num_channels(image)
+        img_w, img_h = F.get_image_size(image)
 
         if isinstance(self.value, (int, float)):
             value = [self.value]
@@ -93,16 +89,24 @@ class RandomErasing(Transform):
         return dict(zip("ijhwv", (i, j, h, w, v)))
 
     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if torch.rand(1) >= self.p:
+        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
+            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
+        elif isinstance(input, features.Image):
+            output = F.erase_image_tensor(input, **params)
+            return features.Image.new_like(input, output)
+        elif isinstance(input, torch.Tensor):
+            return F.erase_image_tensor(input, **params)
+        else:
             return input
 
-        return super()._transform(input, params)
+    def forward(self, *inputs: Any) -> Any:
+        if torch.rand(1) >= self.p:
+            return inputs if len(inputs) > 1 else inputs[0]
+
+        return super().forward(*inputs)
 class RandomMixup(Transform):
-    _DISPATCHER = F.mixup
-    _FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
-
     def __init__(self, *, alpha: float) -> None:
         super().__init__()
         self.alpha = alpha
@@ -111,11 +115,20 @@ class RandomMixup(Transform):
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict(lam=float(self._dist.sample(())))
 
+    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
+            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
+        elif isinstance(input, features.Image):
+            output = F.mixup_image_tensor(input, **params)
+            return features.Image.new_like(input, output)
+        elif isinstance(input, features.OneHotLabel):
+            output = F.mixup_one_hot_label(input, **params)
+            return features.OneHotLabel.new_like(input, output)
+        else:
+            return input
 
 class RandomCutmix(Transform):
-    _DISPATCHER = F.cutmix
-    _FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
-
     def __init__(self, *, alpha: float) -> None:
         super().__init__()
         self.alpha = alpha
@@ -125,7 +138,7 @@ class RandomCutmix(Transform):
         lam = float(self._dist.sample(()))
 
         image = query_image(sample)
-        H, W = F.get_image_size(image)
+        W, H = F.get_image_size(image)
 
         r_x = torch.randint(W, ())
         r_y = torch.randint(H, ())
@@ -143,3 +156,15 @@ class RandomCutmix(Transform):
         lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
 
         return dict(box=box, lam_adjusted=lam_adjusted)
+
+    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+        if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
+            raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
+        elif isinstance(input, features.Image):
+            output = F.cutmix_image_tensor(input, box=params["box"])
+            return features.Image.new_like(input, output)
+        elif isinstance(input, features.OneHotLabel):
+            output = F.cutmix_one_hot_label(input, lam_adjusted=params["lam_adjusted"])
+            return features.OneHotLabel.new_like(input, output)
+        else:
+            return input
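As a quick standalone check of the CutMix bookkeeping above (plain PyTorch, not the prototype API): lam_adjusted is one minus the cut-box area over the image area, and the one-hot labels are then mixed with exactly that coefficient, mirroring cutmix_one_hot_label and _mixup_tensor later in this diff. The concrete numbers are made up for illustration.

import torch

# lam_adjusted = 1 - box_area / image_area, as computed in _get_params above
W, H = 224, 224
x1, y1, x2, y2 = 50, 60, 150, 180
lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
print(lam_adjusted)  # 1 - (100 * 120) / (224 * 224) ~= 0.76

# mix a batch of one-hot labels with that coefficient (out-of-place for clarity)
labels = torch.eye(2)
mixed = labels.roll(1, -2).mul(1 - lam_adjusted).add(labels.mul(lam_adjusted))
print(mixed)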
...@@ -21,9 +21,54 @@ class _AutoAugmentBase(Transform): ...@@ -21,9 +21,54 @@ class _AutoAugmentBase(Transform):
self.interpolation = interpolation self.interpolation = interpolation
self.fill = fill self.fill = fill
_DISPATCHER_MAP: Dict[str, Callable[[Any, float, InterpolationMode, Optional[List[float]]], Any]] = { def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
"Identity": lambda input, magnitude, interpolation, fill: input, keys = tuple(dct.keys())
"ShearX": lambda input, magnitude, interpolation, fill: F.affine( key = keys[int(torch.randint(len(keys), ()))]
return key, dct[key]
def _apply_transform(self, sample: Any, transform_id: str, magnitude: float) -> Any:
def dispatch(
image_tensor_kernel: Callable,
image_pil_kernel: Callable,
input: Any,
*args: Any,
**kwargs: Any,
) -> Any:
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
elif isinstance(input, features.Image):
output = image_tensor_kernel(input, *args, **kwargs)
return features.Image.new_like(input, output)
elif isinstance(input, torch.Tensor):
return image_tensor_kernel(input, *args, **kwargs)
elif isinstance(input, PIL.Image.Image):
return image_pil_kernel(input, *args, **kwargs)
else:
return input
image = query_image(sample)
num_channels = F.get_image_num_channels(image)
fill = self.fill
if isinstance(fill, (int, float)):
fill = [float(fill)] * num_channels
elif fill is not None:
fill = [float(f) for f in fill]
interpolation = self.interpolation
def transform(input: Any) -> Any:
if type(input) in {features.BoundingBox, features.SegmentationMask}:
raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()")
elif not (type(input) in {features.Image, torch.Tensor} or isinstance(input, PIL.Image.Image)):
return input
if transform_id == "Identity":
return input
elif transform_id == "ShearX":
return dispatch(
F.affine_image_tensor,
F.affine_image_pil,
input, input,
angle=0.0, angle=0.0,
translate=[0, 0], translate=[0, 0],
...@@ -31,8 +76,11 @@ class _AutoAugmentBase(Transform): ...@@ -31,8 +76,11 @@ class _AutoAugmentBase(Transform):
shear=[math.degrees(magnitude), 0.0], shear=[math.degrees(magnitude), 0.0],
interpolation=interpolation, interpolation=interpolation,
fill=fill, fill=fill,
), )
"ShearY": lambda input, magnitude, interpolation, fill: F.affine( elif transform_id == "ShearY":
return dispatch(
F.affine_image_tensor,
F.affine_image_pil,
input, input,
angle=0.0, angle=0.0,
translate=[0, 0], translate=[0, 0],
...@@ -40,8 +88,11 @@ class _AutoAugmentBase(Transform): ...@@ -40,8 +88,11 @@ class _AutoAugmentBase(Transform):
shear=[0.0, math.degrees(magnitude)], shear=[0.0, math.degrees(magnitude)],
interpolation=interpolation, interpolation=interpolation,
fill=fill, fill=fill,
), )
"TranslateX": lambda input, magnitude, interpolation, fill: F.affine( elif transform_id == "TranslateX":
return dispatch(
F.affine_image_tensor,
F.affine_image_pil,
input, input,
angle=0.0, angle=0.0,
translate=[int(magnitude), 0], translate=[int(magnitude), 0],
...@@ -49,8 +100,11 @@ class _AutoAugmentBase(Transform): ...@@ -49,8 +100,11 @@ class _AutoAugmentBase(Transform):
shear=[0.0, 0.0], shear=[0.0, 0.0],
interpolation=interpolation, interpolation=interpolation,
fill=fill, fill=fill,
), )
"TranslateY": lambda input, magnitude, interpolation, fill: F.affine( elif transform_id == "TranslateY":
return dispatch(
F.affine_image_tensor,
F.affine_image_pil,
input, input,
angle=0.0, angle=0.0,
translate=[0, int(magnitude)], translate=[0, int(magnitude)],
...@@ -58,54 +112,46 @@ class _AutoAugmentBase(Transform): ...@@ -58,54 +112,46 @@ class _AutoAugmentBase(Transform):
shear=[0.0, 0.0], shear=[0.0, 0.0],
interpolation=interpolation, interpolation=interpolation,
fill=fill, fill=fill,
), )
"Rotate": lambda input, magnitude, interpolation, fill: F.rotate(input, angle=magnitude), elif transform_id == "Rotate":
"Brightness": lambda input, magnitude, interpolation, fill: F.adjust_brightness( return dispatch(F.rotate_image_tensor, F.rotate_image_pil, input, angle=magnitude)
input, brightness_factor=1.0 + magnitude elif transform_id == "Brightness":
), return dispatch(
"Color": lambda input, magnitude, interpolation, fill: F.adjust_saturation( F.adjust_brightness_image_tensor,
input, saturation_factor=1.0 + magnitude F.adjust_brightness_image_pil,
), input,
"Contrast": lambda input, magnitude, interpolation, fill: F.adjust_contrast( brightness_factor=1.0 + magnitude,
input, contrast_factor=1.0 + magnitude )
), elif transform_id == "Color":
"Sharpness": lambda input, magnitude, interpolation, fill: F.adjust_sharpness( return dispatch(
input, sharpness_factor=1.0 + magnitude F.adjust_saturation_image_tensor,
), F.adjust_saturation_image_pil,
"Posterize": lambda input, magnitude, interpolation, fill: F.posterize(input, bits=int(magnitude)), input,
"Solarize": lambda input, magnitude, interpolation, fill: F.solarize(input, threshold=magnitude), saturation_factor=1.0 + magnitude,
"AutoContrast": lambda input, magnitude, interpolation, fill: F.autocontrast(input), )
"Equalize": lambda input, magnitude, interpolation, fill: F.equalize(input), elif transform_id == "Contrast":
"Invert": lambda input, magnitude, interpolation, fill: F.invert(input), return dispatch(
} F.adjust_contrast_image_tensor, F.adjust_contrast_image_pil, input, contrast_factor=1.0 + magnitude
)
def _get_params(self, sample: Any) -> Dict[str, Any]: elif transform_id == "Sharpness":
image = query_image(sample) return dispatch(
num_channels = F.get_image_num_channels(image) F.adjust_sharpness_image_tensor,
F.adjust_sharpness_image_pil,
fill = self.fill input,
if isinstance(fill, (int, float)): sharpness_factor=1.0 + magnitude,
fill = [float(fill)] * num_channels )
elif fill is not None: elif transform_id == "Posterize":
fill = [float(f) for f in fill] return dispatch(F.posterize_image_tensor, F.posterize_image_pil, input, bits=int(magnitude))
elif transform_id == "Solarize":
return dict(interpolation=self.interpolation, fill=fill) return dispatch(F.solarize_image_tensor, F.solarize_image_pil, input, threshold=magnitude)
elif transform_id == "AutoContrast":
def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]: return dispatch(F.autocontrast_image_tensor, F.autocontrast_image_pil, input)
keys = tuple(dct.keys()) elif transform_id == "Equalize":
key = keys[int(torch.randint(len(keys), ()))] return dispatch(F.equalize_image_tensor, F.equalize_image_pil, input)
return key, dct[key] elif transform_id == "Invert":
return dispatch(F.invert_image_tensor, F.invert_image_pil, input)
def _apply_transform(self, sample: Any, params: Dict[str, Any], transform_id: str, magnitude: float) -> Any:
dispatcher = self._DISPATCHER_MAP[transform_id]
def transform(input: Any) -> Any:
if type(input) in {features.Image, torch.Tensor} or isinstance(input, PIL.Image.Image):
return dispatcher(input, magnitude, params["interpolation"], params["fill"])
elif type(input) in {features.BoundingBox, features.SegmentationMask}:
raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()")
else: else:
return input raise ValueError(f"No transform available for {transform_id}")
return apply_recursively(transform, sample) return apply_recursively(transform, sample)
...@@ -114,7 +160,7 @@ class AutoAugment(_AutoAugmentBase): ...@@ -114,7 +160,7 @@ class AutoAugment(_AutoAugmentBase):
_AUGMENTATION_SPACE = { _AUGMENTATION_SPACE = {
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
...@@ -228,9 +274,8 @@ class AutoAugment(_AutoAugmentBase): ...@@ -228,9 +274,8 @@ class AutoAugment(_AutoAugmentBase):
else: else:
raise ValueError(f"The provided policy {policy} is not recognized.") raise ValueError(f"The provided policy {policy} is not recognized.")
def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0] sample = inputs if len(inputs) > 1 else inputs[0]
params = params or self._get_params(sample)
image = query_image(sample) image = query_image(sample)
image_size = F.get_image_size(image) image_size = F.get_image_size(image)
...@@ -251,7 +296,7 @@ class AutoAugment(_AutoAugmentBase): ...@@ -251,7 +296,7 @@ class AutoAugment(_AutoAugmentBase):
else: else:
magnitude = 0.0 magnitude = 0.0
sample = self._apply_transform(sample, params, transform_id, magnitude) sample = self._apply_transform(sample, transform_id, magnitude)
return sample return sample
...@@ -261,7 +306,7 @@ class RandAugment(_AutoAugmentBase): ...@@ -261,7 +306,7 @@ class RandAugment(_AutoAugmentBase):
"Identity": (lambda num_bins, image_size: None, False), "Identity": (lambda num_bins, image_size: None, False),
"ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
...@@ -285,9 +330,8 @@ class RandAugment(_AutoAugmentBase): ...@@ -285,9 +330,8 @@ class RandAugment(_AutoAugmentBase):
self.magnitude = magnitude self.magnitude = magnitude
self.num_magnitude_bins = num_magnitude_bins self.num_magnitude_bins = num_magnitude_bins
def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0] sample = inputs if len(inputs) > 1 else inputs[0]
params = params or self._get_params(sample)
image = query_image(sample) image = query_image(sample)
image_size = F.get_image_size(image) image_size = F.get_image_size(image)
...@@ -303,7 +347,7 @@ class RandAugment(_AutoAugmentBase): ...@@ -303,7 +347,7 @@ class RandAugment(_AutoAugmentBase):
else: else:
magnitude = 0.0 magnitude = 0.0
sample = self._apply_transform(sample, params, transform_id, magnitude) sample = self._apply_transform(sample, transform_id, magnitude)
return sample return sample
...@@ -335,9 +379,8 @@ class TrivialAugmentWide(_AutoAugmentBase): ...@@ -335,9 +379,8 @@ class TrivialAugmentWide(_AutoAugmentBase):
super().__init__(**kwargs) super().__init__(**kwargs)
self.num_magnitude_bins = num_magnitude_bins self.num_magnitude_bins = num_magnitude_bins
def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0] sample = inputs if len(inputs) > 1 else inputs[0]
params = params or self._get_params(sample)
image = query_image(sample) image = query_image(sample)
image_size = F.get_image_size(image) image_size = F.get_image_size(image)
...@@ -352,4 +395,4 @@ class TrivialAugmentWide(_AutoAugmentBase): ...@@ -352,4 +395,4 @@ class TrivialAugmentWide(_AutoAugmentBase):
else: else:
magnitude = 0.0 magnitude = 0.0
return self._apply_transform(sample, params, transform_id, magnitude) return self._apply_transform(sample, transform_id, magnitude)
from typing import Any, Optional, Dict from typing import Any
import torch import torch
...@@ -6,13 +6,13 @@ from ._transform import Transform ...@@ -6,13 +6,13 @@ from ._transform import Transform
class Compose(Transform): class Compose(Transform):
def __init__(self, *transforms: Transform): def __init__(self, *transforms: Transform) -> None:
super().__init__() super().__init__()
self.transforms = transforms self.transforms = transforms
for idx, transform in enumerate(transforms): for idx, transform in enumerate(transforms):
self.add_module(str(idx), transform) self.add_module(str(idx), transform)
def forward(self, *inputs: Any) -> Any: # type: ignore[override] def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0] sample = inputs if len(inputs) > 1 else inputs[0]
for transform in self.transforms: for transform in self.transforms:
sample = transform(sample) sample = transform(sample)
...@@ -25,38 +25,38 @@ class RandomApply(Transform): ...@@ -25,38 +25,38 @@ class RandomApply(Transform):
self.transform = transform self.transform = transform
self.p = p self.p = p
def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0] sample = inputs if len(inputs) > 1 else inputs[0]
if float(torch.rand(())) < self.p: if float(torch.rand(())) < self.p:
return sample return sample
return self.transform(sample, params=params) return self.transform(sample)
def extra_repr(self) -> str: def extra_repr(self) -> str:
return f"p={self.p}" return f"p={self.p}"
class RandomChoice(Transform): class RandomChoice(Transform):
def __init__(self, *transforms: Transform): def __init__(self, *transforms: Transform) -> None:
super().__init__() super().__init__()
self.transforms = transforms self.transforms = transforms
for idx, transform in enumerate(transforms): for idx, transform in enumerate(transforms):
self.add_module(str(idx), transform) self.add_module(str(idx), transform)
def forward(self, *inputs: Any) -> Any: # type: ignore[override] def forward(self, *inputs: Any) -> Any:
idx = int(torch.randint(len(self.transforms), size=())) idx = int(torch.randint(len(self.transforms), size=()))
transform = self.transforms[idx] transform = self.transforms[idx]
return transform(*inputs) return transform(*inputs)
class RandomOrder(Transform): class RandomOrder(Transform):
def __init__(self, *transforms: Transform): def __init__(self, *transforms: Transform) -> None:
super().__init__() super().__init__()
self.transforms = transforms self.transforms = transforms
for idx, transform in enumerate(transforms): for idx, transform in enumerate(transforms):
self.add_module(str(idx), transform) self.add_module(str(idx), transform)
def forward(self, *inputs: Any) -> Any: # type: ignore[override] def forward(self, *inputs: Any) -> Any:
for idx in torch.randperm(len(self.transforms)): for idx in torch.randperm(len(self.transforms)):
transform = self.transforms[idx] transform = self.transforms[idx]
inputs = transform(*inputs) inputs = transform(*inputs)
......
@@ -2,6 +2,7 @@ import math
 import warnings
 from typing import Any, Dict, List, Union, Sequence, Tuple, cast
 
+import PIL.Image
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
@@ -11,41 +12,69 @@ from ._utils import query_image
 
 class HorizontalFlip(Transform):
-    _DISPATCHER = F.horizontal_flip
+    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+        if isinstance(input, features.Image):
+            output = F.horizontal_flip_image_tensor(input)
+            return features.Image.new_like(input, output)
+        elif isinstance(input, features.BoundingBox):
+            output = F.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size)
+            return features.BoundingBox.new_like(input, output)
+        elif isinstance(input, PIL.Image.Image):
+            return F.horizontal_flip_image_pil(input)
+        elif isinstance(input, torch.Tensor):
+            return F.horizontal_flip_image_tensor(input)
+        else:
+            return input
class Resize(Transform): class Resize(Transform):
_DISPATCHER = F.resize
def __init__( def __init__(
self, self,
size: Union[int, Sequence[int]], size: Union[int, Sequence[int]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> None: ) -> None:
super().__init__() super().__init__()
self.size = size self.size = [size] if isinstance(size, int) else list(size)
self.interpolation = interpolation self.interpolation = interpolation
def _get_params(self, sample: Any) -> Dict[str, Any]: def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
return dict(size=self.size, interpolation=self.interpolation) if isinstance(input, features.Image):
output = F.resize_image_tensor(input, self.size, interpolation=self.interpolation)
return features.Image.new_like(input, output)
elif isinstance(input, features.SegmentationMask):
output = F.resize_segmentation_mask(input, self.size)
return features.SegmentationMask.new_like(input, output)
elif isinstance(input, features.BoundingBox):
output = F.resize_bounding_box(input, self.size, image_size=input.image_size)
return features.BoundingBox.new_like(input, output, image_size=self.size)
elif isinstance(input, PIL.Image.Image):
return F.resize_image_pil(input, self.size, interpolation=self.interpolation)
elif isinstance(input, torch.Tensor):
return F.resize_image_tensor(input, self.size, interpolation=self.interpolation)
else:
return input
class CenterCrop(Transform): class CenterCrop(Transform):
_DISPATCHER = F.center_crop
_FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
def __init__(self, output_size: List[int]): def __init__(self, output_size: List[int]):
super().__init__() super().__init__()
self.output_size = output_size self.output_size = output_size
def _get_params(self, sample: Any) -> Dict[str, Any]: def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
return dict(output_size=self.output_size) if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
elif isinstance(input, features.Image):
output = F.center_crop_image_tensor(input, self.output_size)
return features.Image.new_like(input, output)
elif isinstance(input, torch.Tensor):
return F.center_crop_image_tensor(input, self.output_size)
elif isinstance(input, PIL.Image.Image):
return F.center_crop_image_pil(input, self.output_size)
else:
return input
class RandomResizedCrop(Transform): class RandomResizedCrop(Transform):
_DISPATCHER = F.resized_crop
_FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
def __init__( def __init__(
self, self,
size: Union[int, Sequence[int]], size: Union[int, Sequence[int]],
...@@ -80,7 +109,7 @@ class RandomResizedCrop(Transform): ...@@ -80,7 +109,7 @@ class RandomResizedCrop(Transform):
def _get_params(self, sample: Any) -> Dict[str, Any]: def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample) image = query_image(sample)
height, width = F.get_image_size(image) width, height = F.get_image_size(image)
area = height * width area = height * width
log_ratio = torch.log(torch.tensor(self.ratio)) log_ratio = torch.log(torch.tensor(self.ratio))
...@@ -115,4 +144,19 @@ class RandomResizedCrop(Transform): ...@@ -115,4 +144,19 @@ class RandomResizedCrop(Transform):
i = (height - h) // 2 i = (height - h) // 2
j = (width - w) // 2 j = (width - w) // 2
return dict(top=i, left=j, height=h, width=w, size=self.size) return dict(top=i, left=j, height=h, width=w)
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
elif isinstance(input, features.Image):
output = F.resized_crop_image_tensor(
input, **params, size=list(self.size), interpolation=self.interpolation
)
return features.Image.new_like(input, output)
elif isinstance(input, torch.Tensor):
return F.resized_crop_image_tensor(input, **params, size=list(self.size), interpolation=self.interpolation)
elif isinstance(input, PIL.Image.Image):
return F.resized_crop_image_pil(input, **params, size=list(self.size), interpolation=self.interpolation)
else:
return input
from typing import Union, Any, Dict, Optional from typing import Union, Any, Dict, Optional
import PIL.Image
import torch import torch
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F from torchvision.prototype.transforms import Transform, functional as F
...@@ -7,24 +8,18 @@ from torchvision.transforms.functional import convert_image_dtype ...@@ -7,24 +8,18 @@ from torchvision.transforms.functional import convert_image_dtype
class ConvertBoundingBoxFormat(Transform): class ConvertBoundingBoxFormat(Transform):
_DISPATCHER = F.convert_format def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None:
def __init__(
self,
format: Union[str, features.BoundingBoxFormat],
old_format: Optional[Union[str, features.BoundingBoxFormat]] = None,
) -> None:
super().__init__() super().__init__()
if isinstance(format, str): if isinstance(format, str):
format = features.BoundingBoxFormat[format] format = features.BoundingBoxFormat[format]
self.format = format self.format = format
if isinstance(old_format, str): def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
old_format = features.BoundingBoxFormat[old_format] if type(input) is features.BoundingBox:
self.old_format = old_format output = F.convert_bounding_box_format(input, old_format=input.format, new_format=params["format"])
return features.BoundingBox.new_like(input, output, format=params["format"])
def _get_params(self, sample: Any) -> Dict[str, Any]: else:
return dict(format=self.format, old_format=self.old_format) return input
class ConvertImageDtype(Transform): class ConvertImageDtype(Transform):
...@@ -33,21 +28,50 @@ class ConvertImageDtype(Transform): ...@@ -33,21 +28,50 @@ class ConvertImageDtype(Transform):
self.dtype = dtype self.dtype = dtype
def _transform(self, input: Any, params: Dict[str, Any]) -> Any: def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if not isinstance(input, features.Image): if type(input) is features.Image:
return input
output = convert_image_dtype(input, dtype=self.dtype) output = convert_image_dtype(input, dtype=self.dtype)
return features.Image.new_like(input, output, dtype=self.dtype) return features.Image.new_like(input, output, dtype=self.dtype)
else:
return input
class ConvertColorSpace(Transform): class ConvertImageColorSpace(Transform):
_DISPATCHER = F.convert_color_space def __init__(
self,
def __init__(self, color_space: Union[str, features.ColorSpace]) -> None: color_space: Union[str, features.ColorSpace],
old_color_space: Optional[Union[str, features.ColorSpace]] = None,
) -> None:
super().__init__() super().__init__()
if isinstance(color_space, str): if isinstance(color_space, str):
color_space = features.ColorSpace[color_space] color_space = features.ColorSpace[color_space]
self.color_space = color_space self.color_space = color_space
def _get_params(self, sample: Any) -> Dict[str, Any]: if isinstance(old_color_space, str):
return dict(color_space=self.color_space) old_color_space = features.ColorSpace[old_color_space]
self.old_color_space = old_color_space
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, features.Image):
output = F.convert_image_color_space_tensor(
input, old_color_space=input.color_space, new_color_space=self.color_space
)
return features.Image.new_like(input, output, color_space=self.color_space)
elif isinstance(input, torch.Tensor):
if self.old_color_space is None:
raise RuntimeError("")
return F.convert_image_color_space_tensor(
input, old_color_space=self.old_color_space, new_color_space=self.color_space
)
elif isinstance(input, PIL.Image.Image):
old_color_space = {
"L": features.ColorSpace.GRAYSCALE,
"RGB": features.ColorSpace.RGB,
}.get(input.mode, features.ColorSpace.OTHER)
return F.convert_image_color_space_pil(
input, old_color_space=old_color_space, new_color_space=self.color_space
)
else:
return input
@@ -17,10 +17,10 @@ class Lambda(Transform):
         self.types = types
 
     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if not isinstance(input, self.types):
-            return input
-
-        return self.fn(input)
+        if type(input) in self.types:
+            return self.fn(input)
+        else:
+            return input
 
     def extra_repr(self) -> str:
         extras = []
@@ -32,15 +32,18 @@ class Lambda(Transform):
 
 class Normalize(Transform):
-    _DISPATCHER = F.normalize
-
     def __init__(self, mean: List[float], std: List[float]):
         super().__init__()
         self.mean = mean
         self.std = std
 
-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        return dict(mean=self.mean, std=self.std)
+    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+        if isinstance(input, torch.Tensor):
+            # We don't need to differentiate between vanilla tensors and features.Image's here, since the result of the
+            # normalization transform is no longer a features.Image
+            return F.normalize_image_tensor(input, mean=self.mean, std=self.std)
+        else:
+            return input
 
 class ToDtype(Lambda):
...
 import enum
 import functools
-from typing import Any, Dict, Optional, Set, Type
+from typing import Any, Dict
 
 from torch import nn
 from torchvision.prototype.utils._internal import apply_recursively
 from torchvision.utils import _log_api_usage_once
 
-from .functional._utils import Dispatcher
-
 
 class Transform(nn.Module):
-    _DISPATCHER: Optional[Dispatcher] = None
-    _FAIL_TYPES: Set[Type] = set()
-
     def __init__(self) -> None:
         super().__init__()
         _log_api_usage_once(self)
@@ -21,19 +16,11 @@ class Transform(nn.Module):
         return dict()
 
     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if not self._DISPATCHER:
-            raise NotImplementedError()
-
-        if input in self._DISPATCHER:
-            return self._DISPATCHER(input, **params)
-        elif type(input) in self._FAIL_TYPES:
-            raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()")
-        else:
-            return input
+        raise NotImplementedError
 
-    def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any:
+    def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
-        return apply_recursively(functools.partial(self._transform, params=params or self._get_params(sample)), sample)
+        return apply_recursively(functools.partial(self._transform, params=self._get_params(sample)), sample)
 
     def extra_repr(self) -> str:
         extra = []
...
 from typing import Any, Dict
 
 from torchvision.prototype import features
-from torchvision.prototype.transforms import Transform, kernels as K
+from torchvision.prototype.transforms import Transform, functional as F
 
 
 class DecodeImage(Transform):
     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if not isinstance(input, features.EncodedImage):
+        if type(input) is features.EncodedImage:
+            output = F.decode_image_with_pil(input)
+            return features.Image(output)
+        else:
             return input
 
-        return features.Image(K.decode_image_with_pil(input))
 
 class LabelToOneHot(Transform):
     def __init__(self, num_categories: int = -1):
@@ -18,15 +19,14 @@ class LabelToOneHot(Transform):
         self.num_categories = num_categories
 
     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if not isinstance(input, features.Label):
-            return input
-
+        if type(input) is features.Label:
             num_categories = self.num_categories
             if num_categories == -1 and input.categories is not None:
                 num_categories = len(input.categories)
-        return features.OneHotLabel(
-            K.label_to_one_hot(input, num_categories=num_categories), categories=input.categories
-        )
+            output = F.label_to_one_hot(input, num_categories=num_categories)
+            return features.OneHotLabel(output, categories=input.categories)
+        else:
+            return input
 
     def extra_repr(self) -> str:
         if self.num_categories == -1:
...
-from typing import Any, Union, Optional
+from typing import Any, Optional, Union
 
 import PIL.Image
 import torch
...
from ._augment import erase, mixup, cutmix from torchvision.transforms import InterpolationMode # usort: skip
from ._utils import get_image_size, get_image_num_channels # usort: skip
from ._meta_conversion import (
convert_bounding_box_format,
convert_image_color_space_tensor,
convert_image_color_space_pil,
) # usort: skip
from ._augment import (
erase_image_tensor,
mixup_image_tensor,
mixup_one_hot_label,
cutmix_image_tensor,
cutmix_one_hot_label,
)
from ._color import ( from ._color import (
adjust_brightness, adjust_brightness_image_tensor,
adjust_contrast, adjust_brightness_image_pil,
adjust_saturation, adjust_contrast_image_tensor,
adjust_sharpness, adjust_contrast_image_pil,
posterize, adjust_saturation_image_tensor,
solarize, adjust_saturation_image_pil,
autocontrast, adjust_sharpness_image_tensor,
equalize, adjust_sharpness_image_pil,
invert, posterize_image_tensor,
posterize_image_pil,
solarize_image_tensor,
solarize_image_pil,
autocontrast_image_tensor,
autocontrast_image_pil,
equalize_image_tensor,
equalize_image_pil,
invert_image_tensor,
invert_image_pil,
adjust_hue_image_tensor,
adjust_hue_image_pil,
adjust_gamma_image_tensor,
adjust_gamma_image_pil,
)
from ._geometry import (
horizontal_flip_bounding_box,
horizontal_flip_image_tensor,
horizontal_flip_image_pil,
resize_bounding_box,
resize_image_tensor,
resize_image_pil,
resize_segmentation_mask,
center_crop_image_tensor,
center_crop_image_pil,
resized_crop_image_tensor,
resized_crop_image_pil,
affine_image_tensor,
affine_image_pil,
rotate_image_tensor,
rotate_image_pil,
pad_image_tensor,
pad_image_pil,
crop_image_tensor,
crop_image_pil,
perspective_image_tensor,
perspective_image_pil,
vertical_flip_image_tensor,
vertical_flip_image_pil,
) )
from ._geometry import horizontal_flip, resize, center_crop, resized_crop, affine, rotate from ._misc import normalize_image_tensor, gaussian_blur_image_tensor
from ._meta_conversion import convert_color_space, convert_format from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot
from ._misc import normalize, get_image_size, get_image_num_channels
from typing import Any from typing import Tuple
import torch import torch
from torchvision.prototype import features from torchvision.transforms import functional_tensor as _FT
from torchvision.prototype.transforms import kernels as K
from torchvision.transforms import functional as _F
erase_image_tensor = _FT.erase
from ._utils import dispatch
def _mixup_tensor(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor:
@dispatch( input = input.clone()
{ return input.roll(1, batch_dim).mul_(1 - lam).add_(input.mul_(lam))
torch.Tensor: _F.erase,
features.Image: K.erase_image,
} def mixup_image_tensor(image_batch: torch.Tensor, *, lam: float) -> torch.Tensor:
) if image_batch.ndim < 4:
def erase(input: Any, *args: Any, **kwargs: Any) -> Any: raise ValueError("Need a batch of images")
"""TODO: add docstring"""
... return _mixup_tensor(image_batch, -4, lam)
@dispatch( def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float) -> torch.Tensor:
{ if one_hot_label_batch.ndim < 2:
features.Image: K.mixup_image, raise ValueError("Need a batch of one hot labels")
features.OneHotLabel: K.mixup_one_hot_label,
} return _mixup_tensor(one_hot_label_batch, -2, lam)
)
def mixup(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring""" def cutmix_image_tensor(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int]) -> torch.Tensor:
... if image_batch.ndim < 4:
raise ValueError("Need a batch of images")
@dispatch( x1, y1, x2, y2 = box
{ image_rolled = image_batch.roll(1, -4)
features.Image: None,
features.OneHotLabel: None, image_batch = image_batch.clone()
} image_batch[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
) return image_batch
def cutmix(input: Any, *args: Any, **kwargs: Any) -> Any:
"""Perform the CutMix operation as introduced in the paper
`"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" <https://arxiv.org/abs/1905.04899>`_. def cutmix_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam_adjusted: float) -> torch.Tensor:
if one_hot_label_batch.ndim < 2:
Dispatch to the corresponding kernels happens according to this table: raise ValueError("Need a batch of one hot labels")
.. table:: return _mixup_tensor(one_hot_label_batch, -2, lam_adjusted)
:widths: 30 70
==================================================== ================================================================
:class:`~torchvision.prototype.features.Image` :func:`~torch.prototype.transforms.kernels.cutmix_image`
:class:`~torchvision.prototype.features.OneHotLabel` :func:`~torch.prototype.transforms.kernels.cutmix_one_hot_label`
==================================================== ================================================================
Please refer to the kernel documentations for a detailed explanation of the functionality and parameters.
"""
if isinstance(input, features.Image):
kwargs.pop("lam_adjusted", None)
output = K.cutmix_image(input, **kwargs)
return features.Image.new_like(input, output)
elif isinstance(input, features.OneHotLabel):
kwargs.pop("box", None)
output = K.cutmix_one_hot_label(input, **kwargs)
return features.OneHotLabel.new_like(input, output)
raise RuntimeError
from typing import Any from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
import PIL.Image adjust_brightness_image_tensor = _FT.adjust_brightness
import torch adjust_brightness_image_pil = _FP.adjust_brightness
from torchvision.prototype import features
from torchvision.prototype.transforms import kernels as K
from torchvision.transforms import functional as _F
from ._utils import dispatch adjust_saturation_image_tensor = _FT.adjust_saturation
adjust_saturation_image_pil = _FP.adjust_saturation
adjust_contrast_image_tensor = _FT.adjust_contrast
adjust_contrast_image_pil = _FP.adjust_contrast
@dispatch( adjust_sharpness_image_tensor = _FT.adjust_sharpness
{ adjust_sharpness_image_pil = _FP.adjust_sharpness
torch.Tensor: _F.adjust_brightness,
PIL.Image.Image: _F.adjust_brightness,
features.Image: K.adjust_brightness_image,
}
)
def adjust_brightness(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
posterize_image_tensor = _FT.posterize
posterize_image_pil = _FP.posterize
@dispatch( solarize_image_tensor = _FT.solarize
{ solarize_image_pil = _FP.solarize
torch.Tensor: _F.adjust_saturation,
PIL.Image.Image: _F.adjust_saturation,
features.Image: K.adjust_saturation_image,
}
)
def adjust_saturation(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
autocontrast_image_tensor = _FT.autocontrast
autocontrast_image_pil = _FP.autocontrast
@dispatch( equalize_image_tensor = _FT.equalize
{ equalize_image_pil = _FP.equalize
torch.Tensor: _F.adjust_contrast,
PIL.Image.Image: _F.adjust_contrast,
features.Image: K.adjust_contrast_image,
}
)
def adjust_contrast(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
invert_image_tensor = _FT.invert
invert_image_pil = _FP.invert
@dispatch( adjust_hue_image_tensor = _FT.adjust_hue
{ adjust_hue_image_pil = _FP.adjust_hue
torch.Tensor: _F.adjust_sharpness,
PIL.Image.Image: _F.adjust_sharpness,
features.Image: K.adjust_sharpness_image,
}
)
def adjust_sharpness(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
adjust_gamma_image_tensor = _FT.adjust_gamma
@dispatch( adjust_gamma_image_pil = _FP.adjust_gamma
{
torch.Tensor: _F.posterize,
PIL.Image.Image: _F.posterize,
features.Image: K.posterize_image,
}
)
def posterize(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
torch.Tensor: _F.solarize,
PIL.Image.Image: _F.solarize,
features.Image: K.solarize_image,
}
)
def solarize(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
torch.Tensor: _F.autocontrast,
PIL.Image.Image: _F.autocontrast,
features.Image: K.autocontrast_image,
}
)
def autocontrast(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
torch.Tensor: _F.equalize,
PIL.Image.Image: _F.equalize,
features.Image: K.equalize_image,
}
)
def equalize(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
torch.Tensor: _F.invert,
PIL.Image.Image: _F.invert,
features.Image: K.invert_image,
}
)
def invert(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
torch.Tensor: _F.adjust_hue,
PIL.Image.Image: _F.adjust_hue,
features.Image: K.adjust_hue_image,
}
)
def adjust_hue(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
@dispatch(
{
torch.Tensor: _F.adjust_gamma,
PIL.Image.Image: _F.adjust_gamma,
features.Image: K.adjust_gamma_image,
}
)
def adjust_gamma(input: Any, *args: Any, **kwargs: Any) -> Any:
"""TODO: add docstring"""
...
Removed in this commit (dispatcher-based wrappers):

from typing import Any

import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import kernels as K
from torchvision.transforms import functional as _F

from ._utils import dispatch

@dispatch(
    {
        torch.Tensor: _F.hflip,
        PIL.Image.Image: _F.hflip,
        features.Image: K.horizontal_flip_image,
        features.BoundingBox: None,
    },
)
def horizontal_flip(input: Any, *args: Any, **kwargs: Any) -> Any:
    """TODO: add docstring"""
    if isinstance(input, features.BoundingBox):
        output = K.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size)
        return features.BoundingBox.new_like(input, output)

    raise RuntimeError

@dispatch(
    {
        torch.Tensor: _F.resize,
        PIL.Image.Image: _F.resize,
        features.Image: K.resize_image,
        features.SegmentationMask: K.resize_segmentation_mask,
        features.BoundingBox: None,
    }
)
def resize(input: Any, *args: Any, **kwargs: Any) -> Any:
    """TODO: add docstring"""
    if isinstance(input, features.BoundingBox):
        size = kwargs.pop("size")
        output = K.resize_bounding_box(input, size=size, image_size=input.image_size)
        return features.BoundingBox.new_like(input, output, image_size=size)

    raise RuntimeError

@dispatch(
    {
        torch.Tensor: _F.center_crop,
        PIL.Image.Image: _F.center_crop,
        features.Image: K.center_crop_image,
    }
)
def center_crop(input: Any, *args: Any, **kwargs: Any) -> Any:
    """TODO: add docstring"""
    ...

@dispatch(
    {
        torch.Tensor: _F.resized_crop,
        PIL.Image.Image: _F.resized_crop,
        features.Image: K.resized_crop_image,
    }
)
def resized_crop(input: Any, *args: Any, **kwargs: Any) -> Any:
    """TODO: add docstring"""
    ...

@dispatch(
    {
        torch.Tensor: _F.affine,
        PIL.Image.Image: _F.affine,
        features.Image: K.affine_image,
    }
)
def affine(input: Any, *args: Any, **kwargs: Any) -> Any:
    """TODO: add docstring"""
    ...

@dispatch(
    {
        torch.Tensor: _F.rotate,
        PIL.Image.Image: _F.rotate,
        features.Image: K.rotate_image,
    }
)
def rotate(input: Any, *args: Any, **kwargs: Any) -> Any:
    """TODO: add docstring"""
    ...

@dispatch(
    {
        torch.Tensor: _F.pad,
        PIL.Image.Image: _F.pad,
        features.Image: K.pad_image,
    }
)
def pad(input: Any, *args: Any, **kwargs: Any) -> Any:
    """TODO: add docstring"""
    ...

@dispatch(
    {
        torch.Tensor: _F.crop,
        PIL.Image.Image: _F.crop,
        features.Image: K.crop_image,
    }
)
def crop(input: Any, *args: Any, **kwargs: Any) -> Any:
    """TODO: add docstring"""
    ...

@dispatch(
    {
        torch.Tensor: _F.perspective,
        PIL.Image.Image: _F.perspective,
        features.Image: K.perspective_image,
    }
)
def perspective(input: Any, *args: Any, **kwargs: Any) -> Any:
    """TODO: add docstring"""
    ...

@dispatch(
    {
        torch.Tensor: _F.vflip,
        PIL.Image.Image: _F.vflip,
        features.Image: K.vertical_flip_image,
    }
)
def vertical_flip(input: Any, *args: Any, **kwargs: Any) -> Any:
    """TODO: add docstring"""
    ...

@dispatch(
    {
        torch.Tensor: _F.five_crop,
        PIL.Image.Image: _F.five_crop,
        features.Image: K.five_crop_image,
    }
)
def five_crop(input: Any, *args: Any, **kwargs: Any) -> Any:
    """TODO: add docstring"""
    ...

@dispatch(
    {
        torch.Tensor: _F.ten_crop,
        PIL.Image.Image: _F.ten_crop,
        features.Image: K.ten_crop_image,
    }
)
def ten_crop(input: Any, *args: Any, **kwargs: Any) -> Any:
    """TODO: add docstring"""
    ...

Added in this commit (per-type kernels):

import numbers
from typing import Tuple, List, Optional, Sequence, Union

import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import InterpolationMode
from torchvision.prototype.transforms.functional import get_image_size
from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix

from ._meta_conversion import convert_bounding_box_format

horizontal_flip_image_tensor = _FT.hflip
horizontal_flip_image_pil = _FP.hflip

def horizontal_flip_bounding_box(
    bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int]
) -> torch.Tensor:
    shape = bounding_box.shape

    bounding_box = convert_bounding_box_format(
        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
    ).view(-1, 4)

    bounding_box[:, [0, 2]] = image_size[1] - bounding_box[:, [2, 0]]

    return convert_bounding_box_format(
        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
    ).view(shape)

def resize_image_tensor(
    image: torch.Tensor,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
    antialias: Optional[bool] = None,
) -> torch.Tensor:
    new_height, new_width = size
    old_width, old_height = _FT.get_image_size(image)
    num_channels = _FT.get_image_num_channels(image)
    batch_shape = image.shape[:-3]
    return _FT.resize(
        image.reshape((-1, num_channels, old_height, old_width)),
        size=size,
        interpolation=interpolation.value,
        max_size=max_size,
        antialias=antialias,
    ).reshape(batch_shape + (num_channels, new_height, new_width))

def resize_image_pil(
    img: PIL.Image.Image,
    size: Union[Sequence[int], int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
) -> PIL.Image.Image:
    return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation], max_size=max_size)

def resize_segmentation_mask(
    segmentation_mask: torch.Tensor, size: List[int], max_size: Optional[int] = None
) -> torch.Tensor:
    return resize_image_tensor(segmentation_mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)

# TODO: handle max_size
def resize_bounding_box(bounding_box: torch.Tensor, size: List[int], image_size: Tuple[int, int]) -> torch.Tensor:
    old_height, old_width = image_size
    new_height, new_width = size
    ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
    return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape)

vertical_flip_image_tensor = _FT.vflip
vertical_flip_image_pil = _FP.vflip

def _affine_parse_args(
    angle: float,
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    center: Optional[List[float]] = None,
) -> Tuple[float, List[float], List[float], Optional[List[float]]]:
    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if not isinstance(translate, (list, tuple)):
        raise TypeError("Argument translate should be a sequence")

    if len(translate) != 2:
        raise ValueError("Argument translate should be a sequence of length 2")

    if scale <= 0.0:
        raise ValueError("Argument scale should be positive")

    if not isinstance(shear, (numbers.Number, (list, tuple))):
        raise TypeError("Shear should be either a single value or a sequence of two values")

    if not isinstance(interpolation, InterpolationMode):
        raise TypeError("Argument interpolation should be a InterpolationMode")

    if isinstance(angle, int):
        angle = float(angle)

    if isinstance(translate, tuple):
        translate = list(translate)

    if isinstance(shear, numbers.Number):
        shear = [shear, 0.0]

    if isinstance(shear, tuple):
        shear = list(shear)

    if len(shear) == 1:
        shear = [shear[0], shear[0]]

    if len(shear) != 2:
        raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")

    if center is not None and not isinstance(center, (list, tuple)):
        raise TypeError("Argument center should be a sequence")

    return angle, translate, shear, center

def affine_image_tensor(
    img: torch.Tensor,
    angle: float,
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    fill: Optional[List[float]] = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    center_f = [0.0, 0.0]
    if center is not None:
        width, height = get_image_size(img)
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))]

    translate_f = [1.0 * t for t in translate]
    matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)

    return _FT.affine(img, matrix, interpolation=interpolation.value, fill=fill)

def affine_image_pil(
    img: PIL.Image.Image,
    angle: float,
    translate: List[float],
    scale: float,
    shear: List[float],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    fill: Optional[List[float]] = None,
    center: Optional[List[float]] = None,
) -> PIL.Image.Image:
    angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

    # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
    # it is visually better to estimate the center without 0.5 offset
    # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
    if center is None:
        width, height = get_image_size(img)
        center = [width * 0.5, height * 0.5]

    matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)

    return _FP.affine(img, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill, center=center)
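Illustrative call for the tensor affine kernel above; this is a sketch rather than part of the commit, and the import path is assumed:

import torch
from torchvision.prototype.transforms.functional import affine_image_tensor

img = torch.rand(3, 64, 64)
# rotate by 15 degrees around the image center and shift 5 pixels to the right
out = affine_image_tensor(img, angle=15.0, translate=[5.0, 0.0], scale=1.0, shear=[0.0, 0.0])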
def rotate_image_tensor(
    img: torch.Tensor,
    angle: float,
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    expand: bool = False,
    fill: Optional[List[float]] = None,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    center_f = [0.0, 0.0]
    if center is not None:
        width, height = get_image_size(img)
        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))]

    # due to current incoherence of rotation angle direction between affine and rotate implementations
    # we need to set -angle.
    matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
    return _FT.rotate(img, matrix, interpolation=interpolation.value, expand=expand, fill=fill)

def rotate_image_pil(
    img: PIL.Image.Image,
    angle: float,
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    expand: bool = False,
    fill: Optional[List[float]] = None,
    center: Optional[List[float]] = None,
) -> PIL.Image.Image:
    return _FP.rotate(
        img, angle, interpolation=pil_modes_mapping[interpolation], expand=expand, fill=fill, center=center
    )
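A short sketch of the tensor rotate kernel (import path assumed); with expand=True the output canvas grows so the rotated content is not clipped:

import torch
from torchvision.prototype.transforms.functional import rotate_image_tensor

img = torch.rand(3, 32, 48)                                   # H=32, W=48
rotated = rotate_image_tensor(img, angle=90.0, expand=True)   # spatial size becomes roughly 48 x 32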
pad_image_tensor = _FT.pad
pad_image_pil = _FP.pad
crop_image_tensor = _FT.crop
crop_image_pil = _FP.crop
def perspective_image_tensor(
    img: torch.Tensor,
    perspective_coeffs: List[float],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    fill: Optional[List[float]] = None,
) -> torch.Tensor:
    return _FT.perspective(img, perspective_coeffs, interpolation=interpolation.value, fill=fill)

def perspective_image_pil(
    img: PIL.Image.Image,
    perspective_coeffs: float,
    interpolation: InterpolationMode = InterpolationMode.BICUBIC,
    fill: Optional[List[float]] = None,
) -> PIL.Image.Image:
    return _FP.perspective(img, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)
def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
    if isinstance(output_size, numbers.Number):
        return [int(output_size), int(output_size)]
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        return [output_size[0], output_size[0]]
    else:
        return list(output_size)

def _center_crop_compute_padding(crop_height: int, crop_width: int, image_height: int, image_width: int) -> List[int]:
    return [
        (crop_width - image_width) // 2 if crop_width > image_width else 0,
        (crop_height - image_height) // 2 if crop_height > image_height else 0,
        (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
        (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
    ]

def _center_crop_compute_crop_anchor(
    crop_height: int, crop_width: int, image_height: int, image_width: int
) -> Tuple[int, int]:
    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return crop_top, crop_left

def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch.Tensor:
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    image_width, image_height = get_image_size(img)

    if crop_height > image_height or crop_width > image_width:
        padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
        img = pad_image_tensor(img, padding_ltrb, fill=0)

        image_width, image_height = get_image_size(img)
        if crop_width == image_width and crop_height == image_height:
            return img

    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
    return crop_image_tensor(img, crop_top, crop_left, crop_height, crop_width)

def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    image_width, image_height = get_image_size(img)

    if crop_height > image_height or crop_width > image_width:
        padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
        img = pad_image_pil(img, padding_ltrb, fill=0)

        image_width, image_height = get_image_size(img)
        if crop_width == image_width and crop_height == image_height:
            return img

    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
    return crop_image_pil(img, crop_top, crop_left, crop_height, crop_width)
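Hypothetical example for center_crop_image_tensor (import path assumed): when the requested output is larger than the input, the image is padded symmetrically before cropping:

import torch
from torchvision.prototype.transforms.functional import center_crop_image_tensor

small = torch.rand(3, 10, 10)
out = center_crop_image_tensor(small, output_size=[16, 16])  # padded up to 3 x 16 x 16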
def resized_crop_image_tensor(
    img: torch.Tensor,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> torch.Tensor:
    img = crop_image_tensor(img, top, left, height, width)
    return resize_image_tensor(img, size, interpolation=interpolation)

def resized_crop_image_pil(
    img: PIL.Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
    size: List[int],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> PIL.Image.Image:
    img = crop_image_pil(img, top, left, height, width)
    return resize_image_pil(img, size, interpolation=interpolation)
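Since the kernels are per-type, callers are responsible for resizing an image and its bounding boxes with the same target size. A sketch under the assumption that both kernels are importable from torchvision.prototype.transforms.functional:

import torch
from torchvision.prototype.transforms.functional import resize_image_tensor, resize_bounding_box

img = torch.rand(3, 100, 200)                     # H=100, W=200
boxes = torch.tensor([[10.0, 20.0, 50.0, 80.0]])  # XYXY coordinates in the same image
resized_img = resize_image_tensor(img, size=[50, 100])
resized_boxes = resize_bounding_box(boxes, size=[50, 100], image_size=(100, 200))
# both dimensions are halved, so the box becomes [[5., 10., 25., 40.]]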
Removed in this commit (dispatcher-based wrappers):

from typing import Any

import PIL.Image
import torch
from torchvision.ops import box_convert
from torchvision.prototype import features
from torchvision.prototype.transforms import kernels as K
from torchvision.transforms import functional as _F

from ._utils import dispatch

@dispatch(
    {
        torch.Tensor: None,
        features.BoundingBox: None,
    }
)
def convert_format(input: Any, *args: Any, **kwargs: Any) -> Any:
    format = kwargs["format"]
    if type(input) is torch.Tensor:
        old_format = kwargs.get("old_format")
        if old_format is None:
            raise TypeError("For vanilla tensors the `old_format` needs to be provided.")
        return box_convert(input, in_fmt=kwargs["old_format"].name.lower(), out_fmt=format.name.lower())
    elif isinstance(input, features.BoundingBox):
        output = K.convert_bounding_box_format(input, old_format=input.format, new_format=kwargs["format"])
        return features.BoundingBox.new_like(input, output, format=format)

    raise RuntimeError

@dispatch(
    {
        torch.Tensor: None,
        PIL.Image.Image: None,
        features.Image: None,
    }
)
def convert_color_space(input: Any, *args: Any, **kwargs: Any) -> Any:
    color_space = kwargs["color_space"]
    if type(input) is torch.Tensor or isinstance(input, PIL.Image.Image):
        if color_space != features.ColorSpace.GRAYSCALE:
            raise ValueError("For vanilla tensors and PIL images only RGB to grayscale is supported")
        return _F.rgb_to_grayscale(input)
    elif isinstance(input, features.Image):
        output = K.convert_color_space(input, old_color_space=input.color_space, new_color_space=color_space)
        return features.Image.new_like(input, output, color_space=color_space)

    raise RuntimeError

Added in this commit (per-type kernels):

import PIL.Image
import torch
from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP

def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
    xyxy = xywh.clone()
    xyxy[..., 2:] += xyxy[..., :2]
    return xyxy

def _xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
    xywh = xyxy.clone()
    xywh[..., 2:] -= xywh[..., :2]
    return xywh

def _cxcywh_to_xyxy(cxcywh: torch.Tensor) -> torch.Tensor:
    cx, cy, w, h = torch.unbind(cxcywh, dim=-1)
    x1 = cx - 0.5 * w
    y1 = cy - 0.5 * h
    x2 = cx + 0.5 * w
    y2 = cy + 0.5 * h
    return torch.stack((x1, y1, x2, y2), dim=-1)

def _xyxy_to_cxcywh(xyxy: torch.Tensor) -> torch.Tensor:
    x1, y1, x2, y2 = torch.unbind(xyxy, dim=-1)
    cx = (x1 + x2) / 2
    cy = (y1 + y2) / 2
    w = x2 - x1
    h = y2 - y1
    return torch.stack((cx, cy, w, h), dim=-1)

def convert_bounding_box_format(
    bounding_box: torch.Tensor, *, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat
) -> torch.Tensor:
    if new_format == old_format:
        return bounding_box.clone()

    if old_format == BoundingBoxFormat.XYWH:
        bounding_box = _xywh_to_xyxy(bounding_box)
    elif old_format == BoundingBoxFormat.CXCYWH:
        bounding_box = _cxcywh_to_xyxy(bounding_box)

    if new_format == BoundingBoxFormat.XYWH:
        bounding_box = _xyxy_to_xywh(bounding_box)
    elif new_format == BoundingBoxFormat.CXCYWH:
        bounding_box = _xyxy_to_cxcywh(bounding_box)

    return bounding_box
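convert_bounding_box_format always routes through XYXY as the intermediate representation, so conversions compose; a quick round-trip sketch (assuming the function is exposed from torchvision.prototype.transforms.functional):

import torch
from torchvision.prototype.features import BoundingBoxFormat
from torchvision.prototype.transforms.functional import convert_bounding_box_format

xywh = torch.tensor([[10.0, 20.0, 30.0, 40.0]])  # x, y, width, height
cxcywh = convert_bounding_box_format(xywh, old_format=BoundingBoxFormat.XYWH, new_format=BoundingBoxFormat.CXCYWH)
back = convert_bounding_box_format(cxcywh, old_format=BoundingBoxFormat.CXCYWH, new_format=BoundingBoxFormat.XYWH)
assert torch.allclose(back, xywh)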
def _grayscale_to_rgb_tensor(grayscale: torch.Tensor) -> torch.Tensor:
    return grayscale.expand(3, 1, 1)

def convert_image_color_space_tensor(
    image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
) -> torch.Tensor:
    if new_color_space == old_color_space:
        return image.clone()

    if old_color_space == ColorSpace.GRAYSCALE:
        image = _grayscale_to_rgb_tensor(image)

    if new_color_space == ColorSpace.GRAYSCALE:
        image = _FT.rgb_to_grayscale(image)

    return image

def _grayscale_to_rgb_pil(grayscale: PIL.Image.Image) -> PIL.Image.Image:
    return grayscale.convert("RGB")

def convert_image_color_space_pil(
    image: PIL.Image.Image, old_color_space: ColorSpace, new_color_space: ColorSpace
) -> PIL.Image.Image:
    if new_color_space == old_color_space:
        return image.copy()

    if old_color_space == ColorSpace.GRAYSCALE:
        image = _grayscale_to_rgb_pil(image)

    if new_color_space == ColorSpace.GRAYSCALE:
        image = _FP.to_grayscale(image)

    return image
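Sketch of the tensor color-space kernel above (import path assumed, and it assumes ColorSpace defines an RGB member alongside GRAYSCALE):

import torch
from torchvision.prototype.features import ColorSpace
from torchvision.prototype.transforms.functional import convert_image_color_space_tensor

rgb = torch.rand(3, 16, 16)
gray = convert_image_color_space_tensor(rgb, old_color_space=ColorSpace.RGB, new_color_space=ColorSpace.GRAYSCALE)
# _FT.rgb_to_grayscale keeps the channel dimension, so gray has shape (1, 16, 16)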
Removed in this commit (dispatcher-based wrappers):

from typing import Any

import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import kernels as K
from torchvision.transforms import functional as _F
from torchvision.transforms.functional_pil import (
    get_image_size as _get_image_size_pil,
    get_image_num_channels as _get_image_num_channels_pil,
)
from torchvision.transforms.functional_tensor import (
    get_image_size as _get_image_size_tensor,
    get_image_num_channels as _get_image_num_channels_tensor,
)

from ._utils import dispatch

@dispatch(
    {
        torch.Tensor: _F.normalize,
        features.Image: K.normalize_image,
    }
)
def normalize(input: Any, *args: Any, **kwargs: Any) -> Any:
    """TODO: add docstring"""
    ...

@dispatch(
    {
        torch.Tensor: _F.gaussian_blur,
        PIL.Image.Image: _F.gaussian_blur,
        features.Image: K.gaussian_blur_image,
    }
)
def gaussian_blur(input: Any, *args: Any, **kwargs: Any) -> Any:
    """TODO: add docstring"""
    ...

@dispatch(
    {
        torch.Tensor: _get_image_size_tensor,
        PIL.Image.Image: _get_image_size_pil,
        features.Image: None,
        features.BoundingBox: None,
    }
)
def get_image_size(input: Any, *args: Any, **kwargs: Any) -> Any:
    if isinstance(input, (features.Image, features.BoundingBox)):
        return list(input.image_size)

    raise RuntimeError

@dispatch(
    {
        torch.Tensor: _get_image_num_channels_tensor,
        PIL.Image.Image: _get_image_num_channels_pil,
        features.Image: None,
    }
)
def get_image_num_channels(input: Any, *args: Any, **kwargs: Any) -> Any:
    if isinstance(input, features.Image):
        return input.num_channels

    raise RuntimeError

Added in this commit (per-type kernels):

from typing import Optional, List

import PIL.Image
import torch
from torchvision.transforms import functional_tensor as _FT
from torchvision.transforms.functional import to_tensor, to_pil_image

normalize_image_tensor = _FT.normalize

def gaussian_blur_image_tensor(
    img: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor:
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size, kernel_size]
    if len(kernel_size) != 2:
        raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
    for ksize in kernel_size:
        if ksize % 2 == 0 or ksize < 0:
            raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")

    if sigma is None:
        sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]

    if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
        raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
    if isinstance(sigma, (int, float)):
        sigma = [float(sigma), float(sigma)]
    if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
        sigma = [sigma[0], sigma[0]]
    if len(sigma) != 2:
        raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}")
    for s in sigma:
        if s <= 0.0:
            raise ValueError(f"sigma should have positive values. Got {sigma}")

    return _FT.gaussian_blur(img, kernel_size, sigma)

def gaussian_blur_image_pil(img: PIL.Image, kernel_size: List[int], sigma: Optional[List[float]] = None) -> PIL.Image:
    return to_pil_image(gaussian_blur_image_tensor(to_tensor(img), kernel_size=kernel_size, sigma=sigma))
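Usage sketch for the blur kernel above (import path assumed); when sigma is omitted it is derived from the kernel size as ksize * 0.15 + 0.35:

import torch
from torchvision.prototype.transforms.functional import gaussian_blur_image_tensor

img = torch.rand(3, 32, 32)
blurred = gaussian_blur_image_tensor(img, kernel_size=[5, 5])  # sigma defaults to [1.1, 1.1]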