Unverified commit 77c8c91c authored by vfdev, committed by GitHub

[proto] Ported all transforms to the new API (#6305)

* [proto] Added few transforms tests, part 1 (#6262)

* Added supported/unsupported data checks in the tests for cutmix/mixup

* Added RandomRotation, RandomAffine transforms tests

* Added tests for RandomZoomOut, Pad

* Update test_prototype_transforms.py

* Added RandomCrop transform and tests (#6271)

* [proto] Added GaussianBlur transform and tests (#6273)

* Added GaussianBlur transform and tests

* Fixing code format

* Copied correctness test

* [proto] Added random color transforms and tests (#6275)

* Added random color transforms and tests

* Disable smoke test for RandomSolarize, RandomAdjustSharpness

* Added RandomPerspective and tests (#6284)

- replaced real image creation by mocks for other tests

* Added more functional tests (#6285)

* [proto] Added elastic transform and tests (#6295)

* WIP [proto] Added functional elastic transform with tests

* Added more functional tests

* WIP on elastic op

* Added elastic transform and tests

* Added tests

* Added tests for ElasticTransform

* Try to format code as in https://github.com/pytorch/vision/pull/5106



* Fixed bug in affine get_params test

* Implemented RandomErase on PIL input as fallback to tensors (#6309)

Added tests

* Added image_size computation for BoundingBox.rotate if expand (#6319)

* Added image_size computation for BoundingBox.rotate if expand

* Added tests

* Added erase_image_pil and eager/jit erase_image_tensor test (#6320)

* Updates according to the review
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 0ed5d811
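The commits above port ElasticTransform, GaussianBlur, RandomCrop, RandomPerspective, the random color transforms and the PIL fallback for RandomErasing to the new prototype API. A minimal usage sketch, assuming the prototype namespaces exactly as they appear in this diff (torchvision.prototype.transforms and torchvision.prototype.features); the composition below is an illustration, not code from this commit:

import torch
from torchvision.prototype import features
from torchvision.prototype import transforms as T

# Wrap a plain tensor as a prototype Image so the transforms can dispatch on the feature type.
image = features.Image(torch.randint(0, 256, (3, 64, 76), dtype=torch.uint8))

pipeline = T.Compose(
    [
        T.RandomHorizontalFlip(p=0.5),
        T.RandomPerspective(distortion_scale=0.3, p=0.5),
        T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        T.ElasticTransform(alpha=50.0, sigma=5.0),
    ]
)

out = pipeline(image)  # out is again a features.Image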
...@@ -1352,16 +1352,24 @@ def test_ten_crop(device):
    assert_equal(transformed_batch, s_transformed_batch)


def test_elastic_transform_asserts():
    with pytest.raises(TypeError, match="Argument displacement should be a Tensor"):
        _ = F.elastic_transform("abc", displacement=None)

    with pytest.raises(TypeError, match="img should be PIL Image or Tensor"):
        _ = F.elastic_transform("abc", displacement=torch.rand(1))

    img_tensor = torch.rand(1, 3, 32, 24)
    with pytest.raises(ValueError, match="Argument displacement shape should"):
        _ = F.elastic_transform(img_tensor, displacement=torch.rand(1, 2))


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC])
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize(
    "fill",
    [None, [255, 255, 255], (2.0,)],
)
def test_elastic_transform_consistency(device, interpolation, dt, fill):
    script_elastic_transform = torch.jit.script(F.elastic_transform)
......
import functools import functools
import itertools import itertools
import math import math
import os
import numpy as np import numpy as np
import pytest import pytest
...@@ -58,7 +59,7 @@ def make_images( ...@@ -58,7 +59,7 @@ def make_images(
yield make_image(size, color_space=color_space, dtype=dtype) yield make_image(size, color_space=color_space, dtype=dtype)
for color_space, dtype, extra_dims_ in itertools.product(color_spaces, dtypes, extra_dims): for color_space, dtype, extra_dims_ in itertools.product(color_spaces, dtypes, extra_dims):
yield make_image(color_space=color_space, extra_dims=extra_dims_, dtype=dtype) yield make_image(size=sizes[0], color_space=color_space, extra_dims=extra_dims_, dtype=dtype)
def randint_with_tensor_bounds(arg1, arg2=None, **kwargs): def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
...@@ -148,12 +149,12 @@ def make_segmentation_mask(size=None, *, num_categories=80, extra_dims=(), dtype ...@@ -148,12 +149,12 @@ def make_segmentation_mask(size=None, *, num_categories=80, extra_dims=(), dtype
def make_segmentation_masks( def make_segmentation_masks(
image_sizes=((16, 16), (7, 33), (31, 9)), sizes=((16, 16), (7, 33), (31, 9)),
dtypes=(torch.long,), dtypes=(torch.long,),
extra_dims=((), (4,), (2, 3)), extra_dims=((), (4,), (2, 3)),
): ):
for image_size, dtype, extra_dims_ in itertools.product(image_sizes, dtypes, extra_dims): for size, dtype, extra_dims_ in itertools.product(sizes, dtypes, extra_dims):
yield make_segmentation_mask(size=image_size, dtype=dtype, extra_dims=extra_dims_) yield make_segmentation_mask(size=size, dtype=dtype, extra_dims=extra_dims_)
class SampleInput: class SampleInput:
...@@ -199,6 +200,30 @@ def horizontal_flip_bounding_box(): ...@@ -199,6 +200,30 @@ def horizontal_flip_bounding_box():
yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size) yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size)
@register_kernel_info_from_sample_inputs_fn
def horizontal_flip_segmentation_mask():
for mask in make_segmentation_masks():
yield SampleInput(mask)
@register_kernel_info_from_sample_inputs_fn
def vertical_flip_image_tensor():
for image in make_images():
yield SampleInput(image)
@register_kernel_info_from_sample_inputs_fn
def vertical_flip_bounding_box():
for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]):
yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size)
@register_kernel_info_from_sample_inputs_fn
def vertical_flip_segmentation_mask():
for mask in make_segmentation_masks():
yield SampleInput(mask)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def resize_image_tensor(): def resize_image_tensor():
for image, interpolation, max_size, antialias in itertools.product( for image, interpolation, max_size, antialias in itertools.product(
...@@ -403,9 +428,17 @@ def crop_segmentation_mask():


@register_kernel_info_from_sample_inputs_fn
def resized_crop_image_tensor():
    for mask, top, left, height, width, size, antialias in itertools.product(
        make_images(),
        [-8, 9],
        [-8, 9],
        [12],
        [12],
        [(16, 18)],
        [True, False],
    ):
        yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size, antialias=antialias)


@register_kernel_info_from_sample_inputs_fn
...@@ -456,6 +489,19 @@ def pad_bounding_box(): ...@@ -456,6 +489,19 @@ def pad_bounding_box():
yield SampleInput(bounding_box, padding=padding, format=bounding_box.format) yield SampleInput(bounding_box, padding=padding, format=bounding_box.format)
@register_kernel_info_from_sample_inputs_fn
def perspective_image_tensor():
for image, perspective_coeffs, fill in itertools.product(
make_images(extra_dims=((), (4,))),
[
[1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
[0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
],
[None, [128], [12.0]], # fill
):
yield SampleInput(image, perspective_coeffs=perspective_coeffs, fill=fill)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def perspective_bounding_box(): def perspective_bounding_box():
for bounding_box, perspective_coeffs in itertools.product( for bounding_box, perspective_coeffs in itertools.product(
...@@ -487,13 +533,47 @@ def perspective_segmentation_mask(): ...@@ -487,13 +533,47 @@ def perspective_segmentation_mask():
) )
@register_kernel_info_from_sample_inputs_fn
def elastic_image_tensor():
for image, fill in itertools.product(
make_images(extra_dims=((), (4,))),
[None, [128], [12.0]], # fill
):
h, w = image.shape[-2:]
displacement = torch.rand(1, h, w, 2)
yield SampleInput(image, displacement=displacement, fill=fill)
@register_kernel_info_from_sample_inputs_fn
def elastic_bounding_box():
for bounding_box in make_bounding_boxes():
h, w = bounding_box.image_size
displacement = torch.rand(1, h, w, 2)
yield SampleInput(
bounding_box,
format=bounding_box.format,
displacement=displacement,
)
@register_kernel_info_from_sample_inputs_fn
def elastic_segmentation_mask():
for mask in make_segmentation_masks(extra_dims=((), (4,))):
h, w = mask.shape[-2:]
displacement = torch.rand(1, h, w, 2)
yield SampleInput(
mask,
displacement=displacement,
)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def center_crop_image_tensor(): def center_crop_image_tensor():
for image, output_size in itertools.product( for mask, output_size in itertools.product(
make_images(sizes=((16, 16), (7, 33), (31, 9))), make_images(sizes=((16, 16), (7, 33), (31, 9))),
[[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size [[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size
): ):
yield SampleInput(image, output_size) yield SampleInput(mask, output_size)
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
...@@ -507,12 +587,80 @@ def center_crop_bounding_box(): ...@@ -507,12 +587,80 @@ def center_crop_bounding_box():
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def center_crop_segmentation_mask(): def center_crop_segmentation_mask():
for mask, output_size in itertools.product( for mask, output_size in itertools.product(
make_segmentation_masks(image_sizes=((16, 16), (7, 33), (31, 9))), make_segmentation_masks(sizes=((16, 16), (7, 33), (31, 9))),
[[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size [[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size
): ):
yield SampleInput(mask, output_size) yield SampleInput(mask, output_size)
@register_kernel_info_from_sample_inputs_fn
def gaussian_blur_image_tensor():
for image, kernel_size, sigma in itertools.product(
make_images(extra_dims=((4,),)),
[[3, 3]],
[None, [3.0, 3.0]],
):
yield SampleInput(image, kernel_size=kernel_size, sigma=sigma)
@register_kernel_info_from_sample_inputs_fn
def equalize_image_tensor():
for image in make_images(extra_dims=(), color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)):
if image.dtype != torch.uint8:
continue
yield SampleInput(image)
@register_kernel_info_from_sample_inputs_fn
def invert_image_tensor():
for image in make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)):
yield SampleInput(image)
@register_kernel_info_from_sample_inputs_fn
def posterize_image_tensor():
for image, bits in itertools.product(
make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)),
[1, 4, 8],
):
if image.dtype != torch.uint8:
continue
yield SampleInput(image, bits=bits)
@register_kernel_info_from_sample_inputs_fn
def solarize_image_tensor():
for image, threshold in itertools.product(
make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)),
[0.1, 0.5, 127.0],
):
if image.is_floating_point() and threshold > 1.0:
continue
yield SampleInput(image, threshold=threshold)
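As the skip above hints, the solarize threshold is compared against raw pixel values, so it must fall inside the dtype's value range (about [0, 1] for floating-point images, [0, 255] for uint8). A small sketch, assuming the solarize_image_tensor kernel registered here keeps the usual (image, threshold) signature:

import torch
from torchvision.prototype.transforms import functional as F

uint8_img = torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8)
float_img = uint8_img.float() / 255.0

# Pixels at or above the threshold are inverted; the threshold has to match the dtype's range.
out_uint8 = F.solarize_image_tensor(uint8_img, threshold=127.0)
out_float = F.solarize_image_tensor(float_img, threshold=0.5)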
@register_kernel_info_from_sample_inputs_fn
def autocontrast_image_tensor():
for image in make_images(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)):
yield SampleInput(image)
@register_kernel_info_from_sample_inputs_fn
def adjust_sharpness_image_tensor():
for image, sharpness_factor in itertools.product(
make_images(extra_dims=((4,),), color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)),
[0.1, 0.5],
):
yield SampleInput(image, sharpness_factor=sharpness_factor)
@register_kernel_info_from_sample_inputs_fn
def erase_image_tensor():
for image in make_images():
c = image.shape[-3]
yield SampleInput(image, i=1, j=2, h=6, w=7, v=torch.rand(c, 6, 7))
@pytest.mark.parametrize( @pytest.mark.parametrize(
"kernel", "kernel",
[ [
...@@ -546,9 +694,19 @@ def test_scriptable(kernel): ...@@ -546,9 +694,19 @@ def test_scriptable(kernel):
and all( and all(
feature_type not in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label", "pil"} feature_type not in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label", "pil"}
) )
and name not in {"to_image_tensor", "InterpolationMode", "decode_video_with_av", "crop", "rotate"} and name
not in {
"to_image_tensor",
"InterpolationMode",
"decode_video_with_av",
"crop",
"perspective",
"elastic_transform",
"elastic",
}
# We skip 'crop' due to missing 'height' and 'width' # We skip 'crop' due to missing 'height' and 'width'
# We skip 'rotate' due to non implemented yet expand=True case for bboxes # We skip 'perspective' as it requires different input args than perspective_image_tensor etc
# Skip 'elastic', TODO: inspect why test is failing
], ],
) )
def test_functional_mid_level(func): def test_functional_mid_level(func):
...@@ -561,7 +719,9 @@ def test_functional_mid_level(func): ...@@ -561,7 +719,9 @@ def test_functional_mid_level(func):
if key in kwargs: if key in kwargs:
del kwargs[key] del kwargs[key]
output = func(*sample_input.args, **kwargs) output = func(*sample_input.args, **kwargs)
torch.testing.assert_close(output, expected, msg=f"finfo={finfo}, output={output}, expected={expected}") torch.testing.assert_close(
output, expected, msg=f"finfo={finfo.name}, output={output}, expected={expected}"
)
break break
...@@ -844,6 +1004,9 @@ def test_correctness_rotate_bounding_box(angle, expand, center): ...@@ -844,6 +1004,9 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
out_bbox[2] -= tr_x out_bbox[2] -= tr_x
out_bbox[3] -= tr_y out_bbox[3] -= tr_y
# image_size should be updated, but it is OK here to skip its computation
# as we do not compute it in F.rotate_bounding_box
out_bbox = features.BoundingBox( out_bbox = features.BoundingBox(
out_bbox, out_bbox,
format=features.BoundingBoxFormat.XYXY, format=features.BoundingBoxFormat.XYXY,
...@@ -1126,6 +1289,18 @@ def test_correctness_crop_segmentation_mask(device, top, left, height, width): ...@@ -1126,6 +1289,18 @@ def test_correctness_crop_segmentation_mask(device, top, left, height, width):
torch.testing.assert_close(output_mask, expected_mask) torch.testing.assert_close(output_mask, expected_mask)
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_correctness_horizontal_flip_segmentation_mask_on_fixed_input(device):
mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
mask[:, :, 0] = 1
out_mask = F.horizontal_flip_segmentation_mask(mask)
expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
expected_mask[:, :, -1] = 1
torch.testing.assert_close(out_mask, expected_mask)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device): def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device) mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
...@@ -1565,3 +1740,102 @@ def test_correctness_center_crop_segmentation_mask(device, output_size): ...@@ -1565,3 +1740,102 @@ def test_correctness_center_crop_segmentation_mask(device, output_size):
expected = _compute_expected_segmentation_mask(mask, output_size) expected = _compute_expected_segmentation_mask(mask, output_size)
torch.testing.assert_close(expected, actual) torch.testing.assert_close(expected, actual)
# Copied from test/test_functional_tensor.py
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("image_size", ("small", "large"))
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)])
@pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)])
def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, sigma):
fn = F.gaussian_blur_image_tensor
# true_cv2_results = {
# # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
# # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8)
# "3_3_0.8": ...
# # cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5)
# "3_3_0.5": ...
# # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8)
# "3_5_0.8": ...
# # cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5)
# "3_5_0.5": ...
# # np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
# # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7)
# "23_23_1.7": ...
# }
p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt")
true_cv2_results = torch.load(p)
if image_size == "small":
tensor = (
torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1).to(device)
)
else:
tensor = torch.from_numpy(np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))).to(device)
if dt == torch.float16 and device == "cpu":
# skip float16 on CPU case
return
if dt is not None:
tensor = tensor.to(dtype=dt)
_ksize = (ksize, ksize) if isinstance(ksize, int) else ksize
_sigma = sigma[0] if sigma is not None else None
shape = tensor.shape
gt_key = f"{shape[-2]}_{shape[-1]}_{shape[-3]}__{_ksize[0]}_{_ksize[1]}_{_sigma}"
if gt_key not in true_cv2_results:
return
true_out = (
torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)
)
image = features.Image(tensor)
out = fn(image, kernel_size=ksize, sigma=sigma)
torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
"fn, make_samples", [(F.elastic_image_tensor, make_images), (F.elastic_segmentation_mask, make_segmentation_masks)]
)
def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples):
in_box = [10, 15, 25, 35]
for sample in make_samples(sizes=((64, 76),), extra_dims=((), (4,))):
c, h, w = sample.shape[-3:]
# Setup a dummy image with 4 points
sample[..., in_box[1], in_box[0]] = torch.tensor([12, 34, 96, 112])[:c]
sample[..., in_box[3] - 1, in_box[0]] = torch.tensor([12, 34, 96, 112])[:c]
sample[..., in_box[3] - 1, in_box[2] - 1] = torch.tensor([12, 34, 96, 112])[:c]
sample[..., in_box[1], in_box[2] - 1] = torch.tensor([12, 34, 96, 112])[:c]
sample = sample.to(device)
if fn == F.elastic_image_tensor:
sample = features.Image(sample)
kwargs = {"interpolation": F.InterpolationMode.NEAREST}
else:
sample = features.SegmentationMask(sample)
kwargs = {}
# Create a displacement grid using sin
n, m = 5.0, 0.1
d1 = m * torch.sin(torch.arange(h, dtype=torch.float) * torch.pi * n / h)
d2 = m * torch.sin(torch.arange(w, dtype=torch.float) * torch.pi * n / w)
d1 = d1[:, None].expand((h, w))
d2 = d2[None, :].expand((h, w))
displacement = torch.cat([d1[..., None], d2[..., None]], dim=-1)
displacement = displacement.reshape(1, h, w, 2)
output = fn(sample, displacement=displacement, **kwargs)
# Check places where transformed points should be
torch.testing.assert_close(output[..., 12, 9], sample[..., in_box[1], in_box[0]])
torch.testing.assert_close(output[..., 17, 27], sample[..., in_box[1], in_box[2] - 1])
torch.testing.assert_close(output[..., 31, 6], sample[..., in_box[3] - 1, in_box[0]])
torch.testing.assert_close(output[..., 37, 23], sample[..., in_box[3] - 1, in_box[2] - 1])
...@@ -5,6 +5,8 @@ from typing import Any, List, Optional, Sequence, Tuple, Union ...@@ -5,6 +5,8 @@ from typing import Any, List, Optional, Sequence, Tuple, Union
import torch import torch
from torchvision._utils import StrEnum from torchvision._utils import StrEnum
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import _get_inverse_affine_matrix
from torchvision.transforms.functional_tensor import _compute_output_size
from ._feature import _Feature from ._feature import _Feature
...@@ -168,10 +170,18 @@ class BoundingBox(_Feature): ...@@ -168,10 +170,18 @@ class BoundingBox(_Feature):
output = _F.rotate_bounding_box( output = _F.rotate_bounding_box(
self, format=self.format, image_size=self.image_size, angle=angle, expand=expand, center=center self, format=self.format, image_size=self.image_size, angle=angle, expand=expand, center=center
) )
# TODO: update output image size if expand is True image_size = self.image_size
if expand: if expand:
raise RuntimeError("Not yet implemented") # The way we recompute image_size is not optimal due to redundant computations of
return BoundingBox.new_like(self, output, dtype=output.dtype) # - rotation matrix (_get_inverse_affine_matrix)
# - points dot matrix (_compute_output_size)
# Alternatively, we could return new image size by _F.rotate_bounding_box
height, width = image_size
rotation_matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, [0.0, 0.0], 1.0, [0.0, 0.0])
new_width, new_height = _compute_output_size(rotation_matrix, width, height)
image_size = (new_height, new_width)
return BoundingBox.new_like(self, output, dtype=output.dtype, image_size=image_size)
def affine( def affine(
self, self,
...@@ -207,3 +217,14 @@ class BoundingBox(_Feature): ...@@ -207,3 +217,14 @@ class BoundingBox(_Feature):
output = _F.perspective_bounding_box(self, self.format, perspective_coeffs) output = _F.perspective_bounding_box(self, self.format, perspective_coeffs)
return BoundingBox.new_like(self, output, dtype=output.dtype) return BoundingBox.new_like(self, output, dtype=output.dtype)
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> BoundingBox:
from torchvision.prototype.transforms import functional as _F
output = _F.elastic_bounding_box(self, self.format, displacement)
return BoundingBox.new_like(self, output, dtype=output.dtype)
...@@ -157,6 +157,14 @@ class _Feature(torch.Tensor): ...@@ -157,6 +157,14 @@ class _Feature(torch.Tensor):
) -> Any: ) -> Any:
return self return self
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Any:
return self
def adjust_brightness(self, brightness_factor: float) -> Any: def adjust_brightness(self, brightness_factor: float) -> Any:
return self return self
...@@ -189,3 +197,6 @@ class _Feature(torch.Tensor): ...@@ -189,3 +197,6 @@ class _Feature(torch.Tensor):
def invert(self) -> Any: def invert(self) -> Any:
return self return self
def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Any:
return self
...@@ -74,7 +74,7 @@ class Image(_Feature): ...@@ -74,7 +74,7 @@ class Image(_Feature):
@property @property
def image_size(self) -> Tuple[int, int]: def image_size(self) -> Tuple[int, int]:
return cast(Tuple[int, int], self.shape[-2:]) return cast(Tuple[int, int], tuple(self.shape[-2:]))
@property @property
def num_channels(self) -> int: def num_channels(self) -> int:
...@@ -243,6 +243,19 @@ class Image(_Feature): ...@@ -243,6 +243,19 @@ class Image(_Feature):
output = _F.perspective_image_tensor(self, perspective_coeffs, interpolation=interpolation, fill=fill) output = _F.perspective_image_tensor(self, perspective_coeffs, interpolation=interpolation, fill=fill)
return Image.new_like(self, output) return Image.new_like(self, output)
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image:
from torchvision.prototype.transforms.functional import _geometry as _F
fill = _F._convert_fill_arg(fill)
output = _F.elastic_image_tensor(self, displacement, interpolation=interpolation, fill=fill)
return Image.new_like(self, output)
def adjust_brightness(self, brightness_factor: float) -> Image: def adjust_brightness(self, brightness_factor: float) -> Image:
from torchvision.prototype.transforms import functional as _F from torchvision.prototype.transforms import functional as _F
...@@ -308,3 +321,9 @@ class Image(_Feature): ...@@ -308,3 +321,9 @@ class Image(_Feature):
output = _F.invert_image_tensor(self) output = _F.invert_image_tensor(self)
return Image.new_like(self, output) return Image.new_like(self, output)
def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image:
from torchvision.prototype.transforms import functional as _F
output = _F.gaussian_blur_image_tensor(self, kernel_size=kernel_size, sigma=sigma)
return Image.new_like(self, output)
...@@ -2,6 +2,7 @@ from __future__ import annotations ...@@ -2,6 +2,7 @@ from __future__ import annotations
from typing import List, Optional, Sequence, Union from typing import List, Optional, Sequence, Union
import torch
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
from ._feature import _Feature from ._feature import _Feature
...@@ -119,3 +120,14 @@ class SegmentationMask(_Feature): ...@@ -119,3 +120,14 @@ class SegmentationMask(_Feature):
output = _F.perspective_segmentation_mask(self, perspective_coeffs) output = _F.perspective_segmentation_mask(self, perspective_coeffs)
return SegmentationMask.new_like(self, output) return SegmentationMask.new_like(self, output)
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> SegmentationMask:
from torchvision.prototype.transforms import functional as _F
output = _F.elastic_segmentation_mask(self, displacement)
return SegmentationMask.new_like(self, output, dtype=output.dtype)
...@@ -4,15 +4,27 @@ from ._transform import Transform # usort: skip ...@@ -4,15 +4,27 @@ from ._transform import Transform # usort: skip
from ._augment import RandomCutmix, RandomErasing, RandomMixup from ._augment import RandomCutmix, RandomErasing, RandomMixup
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._color import ColorJitter, RandomEqualize, RandomPhotometricDistort from ._color import (
ColorJitter,
RandomAdjustSharpness,
RandomAutocontrast,
RandomEqualize,
RandomInvert,
RandomPhotometricDistort,
RandomPosterize,
RandomSolarize,
)
from ._container import Compose, RandomApply, RandomChoice, RandomOrder from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from ._geometry import ( from ._geometry import (
BatchMultiCrop, BatchMultiCrop,
CenterCrop, CenterCrop,
ElasticTransform,
FiveCrop, FiveCrop,
Pad, Pad,
RandomAffine, RandomAffine,
RandomCrop,
RandomHorizontalFlip, RandomHorizontalFlip,
RandomPerspective,
RandomResizedCrop, RandomResizedCrop,
RandomRotation, RandomRotation,
RandomVerticalFlip, RandomVerticalFlip,
...@@ -21,7 +33,7 @@ from ._geometry import ( ...@@ -21,7 +33,7 @@ from ._geometry import (
TenCrop, TenCrop,
) )
from ._meta import ConvertBoundingBoxFormat, ConvertImageColorSpace, ConvertImageDtype from ._meta import ConvertBoundingBoxFormat, ConvertImageColorSpace, ConvertImageDtype
from ._misc import Identity, Lambda, Normalize, ToDtype from ._misc import GaussianBlur, Identity, Lambda, Normalize, ToDtype
from ._type_conversion import DecodeImage, LabelToOneHot from ._type_conversion import DecodeImage, LabelToOneHot
from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip
...@@ -92,8 +92,7 @@ class RandomErasing(_RandomApplyTransform):
                return features.Image.new_like(inpt, output)
            return output
        elif isinstance(inpt, PIL.Image.Image):
            return F.erase_image_pil(inpt, **params)
        else:
            return inpt
......
...@@ -151,8 +151,42 @@ class RandomPhotometricDistort(Transform):

class RandomEqualize(_RandomApplyTransform):
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.equalize(inpt)


class RandomInvert(_RandomApplyTransform):
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.invert(inpt)


class RandomPosterize(_RandomApplyTransform):
    def __init__(self, bits: int, p: float = 0.5) -> None:
        super().__init__(p=p)
        self.bits = bits

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.posterize(inpt, bits=self.bits)


class RandomSolarize(_RandomApplyTransform):
    def __init__(self, threshold: float, p: float = 0.5) -> None:
        super().__init__(p=p)
        self.threshold = threshold

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.solarize(inpt, threshold=self.threshold)


class RandomAutocontrast(_RandomApplyTransform):
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.autocontrast(inpt)


class RandomAdjustSharpness(_RandomApplyTransform):
    def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
        super().__init__(p=p)
        self.sharpness_factor = sharpness_factor

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.adjust_sharpness(inpt, sharpness_factor=self.sharpness_factor)
...@@ -35,7 +35,8 @@ class Resize(Transform): ...@@ -35,7 +35,8 @@ class Resize(Transform):
antialias: Optional[bool] = None, antialias: Optional[bool] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.size = [size] if isinstance(size, int) else list(size)
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
self.interpolation = interpolation self.interpolation = interpolation
self.max_size = max_size self.max_size = max_size
self.antialias = antialias self.antialias = antialias
...@@ -80,7 +81,6 @@ class RandomResizedCrop(Transform): ...@@ -80,7 +81,6 @@ class RandomResizedCrop(Transform):
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)") warnings.warn("Scale and ratio should be of kind (min, max)")
self.size = size
self.scale = scale self.scale = scale
self.ratio = ratio self.ratio = ratio
self.interpolation = interpolation self.interpolation = interpolation
...@@ -225,6 +225,21 @@ def _check_fill_arg(fill: Union[int, float, Sequence[int], Sequence[float]]) -> ...@@ -225,6 +225,21 @@ def _check_fill_arg(fill: Union[int, float, Sequence[int], Sequence[float]]) ->
raise TypeError("Got inappropriate fill arg") raise TypeError("Got inappropriate fill arg")
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)
# https://github.com/pytorch/vision/issues/6250
def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
class Pad(Transform): class Pad(Transform):
def __init__( def __init__(
self, self,
...@@ -233,18 +248,10 @@ class Pad(Transform): ...@@ -233,18 +248,10 @@ class Pad(Transform):
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None: ) -> None:
super().__init__() super().__init__()
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
_check_padding_arg(padding)
_check_fill_arg(fill) _check_fill_arg(fill)
_check_padding_mode_arg(padding_mode)
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]:
raise ValueError(
f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
)
self.padding = padding self.padding = padding
self.fill = fill self.fill = fill
...@@ -258,7 +265,7 @@ class RandomZoomOut(_RandomApplyTransform): ...@@ -258,7 +265,7 @@ class RandomZoomOut(_RandomApplyTransform):
def __init__( def __init__(
self, self,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0, fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
side_range: Tuple[float, float] = (1.0, 4.0), side_range: Sequence[float] = (1.0, 4.0),
p: float = 0.5, p: float = 0.5,
) -> None: ) -> None:
super().__init__(p=p) super().__init__(p=p)
...@@ -266,6 +273,8 @@ class RandomZoomOut(_RandomApplyTransform): ...@@ -266,6 +273,8 @@ class RandomZoomOut(_RandomApplyTransform):
_check_fill_arg(fill) _check_fill_arg(fill)
self.fill = fill self.fill = fill
_check_sequence_input(side_range, "side_range", req_sizes=(2,))
self.side_range = side_range self.side_range = side_range
if side_range[0] < 1.0 or side_range[0] > side_range[1]: if side_range[0] < 1.0 or side_range[0] > side_range[1]:
raise ValueError(f"Invalid canvas side range provided {side_range}.") raise ValueError(f"Invalid canvas side range provided {side_range}.")
...@@ -285,6 +294,7 @@ class RandomZoomOut(_RandomApplyTransform): ...@@ -285,6 +294,7 @@ class RandomZoomOut(_RandomApplyTransform):
bottom = canvas_height - (top + orig_h) bottom = canvas_height - (top + orig_h)
padding = [left, top, right, bottom] padding = [left, top, right, bottom]
# vfdev-5: Can we put that into pad_image_tensor ?
fill = self.fill fill = self.fill
if not isinstance(fill, collections.abc.Sequence): if not isinstance(fill, collections.abc.Sequence):
fill = [fill] * orig_c fill = [fill] * orig_c
...@@ -414,3 +424,203 @@ class RandomAffine(Transform): ...@@ -414,3 +424,203 @@ class RandomAffine(Transform):
fill=self.fill, fill=self.fill,
center=self.center, center=self.center,
) )
class RandomCrop(Transform):
def __init__(
self,
size: Union[int, Sequence[int]],
padding: Optional[Union[int, Sequence[int]]] = None,
pad_if_needed: bool = False,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
self.padding = padding
self.pad_if_needed = pad_if_needed
self.fill = fill
self.padding_mode = padding_mode
self._pad_op = None
if self.padding is not None:
self._pad_op = Pad(self.padding, fill=self.fill, padding_mode=self.padding_mode)
if self.pad_if_needed:
self._pad_op = Pad(0, fill=self.fill, padding_mode=self.padding_mode)
def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
_, height, width = get_image_dimensions(image)
output_height, output_width = self.size
if height + 1 < output_height or width + 1 < output_width:
raise ValueError(
f"Required crop size {(output_height, output_width)} is larger then input image size {(height, width)}"
)
if width == output_width and height == output_height:
return dict(top=0, left=0, height=height, width=width)
top = torch.randint(0, height - output_height + 1, size=(1,)).item()
left = torch.randint(0, width - output_width + 1, size=(1,)).item()
return dict(top=top, left=left, height=output_height, width=output_width)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.crop(inpt, **params)
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if self._pad_op is not None:
sample = self._pad_op(sample)
image = query_image(sample)
_, height, width = get_image_dimensions(image)
if self.pad_if_needed:
# This check is to explicitly ensure that self._pad_op is defined
if self._pad_op is None:
raise RuntimeError(
"Internal error, self._pad_op is None. "
"Please, fill an issue about that on https://github.com/pytorch/vision/issues"
)
# pad the width if needed
if width < self.size[1]:
self._pad_op.padding = [self.size[1] - width, 0]
sample = self._pad_op(sample)
# pad the height if needed
if height < self.size[0]:
self._pad_op.padding = [0, self.size[0] - height]
sample = self._pad_op(sample)
return super().forward(sample)
class RandomPerspective(_RandomApplyTransform):
def __init__(
self,
distortion_scale: float,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
p: float = 0.5,
) -> None:
super().__init__(p=p)
_check_fill_arg(fill)
if not (0 <= distortion_scale <= 1):
raise ValueError("Argument distortion_scale value should be between 0 and 1")
self.distortion_scale = distortion_scale
self.interpolation = interpolation
self.fill = fill
def _get_params(self, sample: Any) -> Dict[str, Any]:
# Get image size
# TODO: make it work with bboxes and segm masks
image = query_image(sample)
_, height, width = get_image_dimensions(image)
distortion_scale = self.distortion_scale
half_height = height // 2
half_width = width // 2
topleft = [
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
]
topright = [
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
]
botright = [
int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
]
botleft = [
int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
]
startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
endpoints = [topleft, topright, botright, botleft]
return dict(startpoints=startpoints, endpoints=endpoints)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.perspective(
inpt,
**params,
fill=self.fill,
interpolation=self.interpolation,
)
def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]:
if not isinstance(arg, (float, Sequence)):
raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}")
if isinstance(arg, Sequence) and len(arg) != req_size:
raise ValueError(f"If {name} is a sequence its length should be one of {req_size}. Got {len(arg)}")
if isinstance(arg, Sequence):
for element in arg:
if not isinstance(element, float):
raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}")
if isinstance(arg, float):
arg = [float(arg), float(arg)]
if isinstance(arg, (list, tuple)) and len(arg) == 1:
arg = [arg[0], arg[0]]
return arg
class ElasticTransform(Transform):
def __init__(
self,
alpha: Union[float, Sequence[float]] = 50.0,
sigma: Union[float, Sequence[float]] = 5.0,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> None:
super().__init__()
self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
_check_fill_arg(fill)
self.interpolation = interpolation
self.fill = fill
def _get_params(self, sample: Any) -> Dict[str, Any]:
# Get image size
# TODO: make it work with bboxes and segm masks
image = query_image(sample)
_, *size = get_image_dimensions(image)
dx = torch.rand([1, 1] + size) * 2 - 1
if self.sigma[0] > 0.0:
kx = int(8 * self.sigma[0] + 1)
# if kernel size is even we have to make it odd
if kx % 2 == 0:
kx += 1
dx = F.gaussian_blur(dx, [kx, kx], list(self.sigma))
dx = dx * self.alpha[0] / size[0]
dy = torch.rand([1, 1] + size) * 2 - 1
if self.sigma[1] > 0.0:
ky = int(8 * self.sigma[1] + 1)
# if kernel size is even we have to make it odd
if ky % 2 == 0:
ky += 1
dy = F.gaussian_blur(dy, [ky, ky], list(self.sigma))
dy = dy * self.alpha[1] / size[1]
displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2
return dict(displacement=displacement)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.elastic(
inpt,
**params,
fill=self.fill,
interpolation=self.interpolation,
)
import functools import functools
from typing import Any, Callable, Dict, List, Type from typing import Any, Callable, Dict, List, Sequence, Type, Union
import torch import torch
from torchvision.prototype.transforms import functional as F, Transform from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.transforms import _setup_size
class Identity(Transform): class Identity(Transform):
...@@ -46,6 +47,36 @@ class Normalize(Transform): ...@@ -46,6 +47,36 @@ class Normalize(Transform):
return input return input
class GaussianBlur(Transform):
def __init__(
self, kernel_size: Union[int, Sequence[int]], sigma: Union[float, Sequence[float]] = (0.1, 2.0)
) -> None:
super().__init__()
self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
for ks in self.kernel_size:
if ks <= 0 or ks % 2 == 0:
raise ValueError("Kernel size value should be an odd and positive number.")
if isinstance(sigma, float):
if sigma <= 0:
raise ValueError("If sigma is a single number, it must be positive.")
sigma = (sigma, sigma)
elif isinstance(sigma, Sequence) and len(sigma) == 2:
if not 0.0 < sigma[0] <= sigma[1]:
raise ValueError("sigma values should be positive and of the form (min, max).")
else:
raise TypeError("sigma should be a single float or a list/tuple with length 2 floats.")
self.sigma = sigma
def _get_params(self, sample: Any) -> Dict[str, Any]:
sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item()
return dict(sigma=[sigma, sigma])
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.gaussian_blur(inpt, **params)
class ToDtype(Lambda): class ToDtype(Lambda):
def __init__(self, dtype: torch.dtype, *types: Type) -> None: def __init__(self, dtype: torch.dtype, *types: Type) -> None:
self.dtype = dtype self.dtype = dtype
......
import enum
from typing import Any, Dict

import PIL.Image
import torch
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype.features import _Feature
from torchvision.utils import _log_api_usage_once
...@@ -16,12 +17,20 @@ class Transform(nn.Module):
    def _get_params(self, sample: Any) -> Dict[str, Any]:
        return dict()

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        raise NotImplementedError

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]

        params = self._get_params(sample)

        flat_inputs, spec = tree_flatten(sample)
        transformed_types = (torch.Tensor, _Feature, PIL.Image.Image)
        flat_outputs = [
            self._transform(inpt, params) if isinstance(inpt, transformed_types) else inpt for inpt in flat_inputs
        ]
        return tree_unflatten(flat_outputs, spec)

    def extra_repr(self) -> str:
        extra = []
......
from typing import Any, Tuple, Type, Union

import PIL.Image
import torch
from torch.utils._pytree import tree_flatten
from torchvision.prototype import features

from .functional._meta import get_dimensions_image_pil, get_dimensions_image_tensor


def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
    flat_sample, _ = tree_flatten(sample)
    for i in flat_sample:
        if type(i) == torch.Tensor or isinstance(i, (PIL.Image.Image, features.Image)):
            return i

    raise TypeError("No image was found in the sample")
...@@ -36,16 +30,14 @@ def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]):
    return channels, height, width


def has_any(sample: Any, *types: Type) -> bool:
    flat_sample, _ = tree_flatten(sample)
    return any(issubclass(type(obj), types) for obj in flat_sample)


def has_all(sample: Any, *types: Type) -> bool:
    flat_sample, _ = tree_flatten(sample)
    return not bool(set(types) - set([type(obj) for obj in flat_sample]))


def is_simple_tensor(input: Any) -> bool:
......
...@@ -5,7 +5,7 @@ from ._meta import ( ...@@ -5,7 +5,7 @@ from ._meta import (
convert_image_color_space_pil, convert_image_color_space_pil,
) # usort: skip ) # usort: skip
from ._augment import erase_image_tensor from ._augment import erase_image_pil, erase_image_tensor
from ._color import ( from ._color import (
adjust_brightness, adjust_brightness,
adjust_brightness_image_pil, adjust_brightness_image_pil,
...@@ -57,6 +57,12 @@ from ._geometry import ( ...@@ -57,6 +57,12 @@ from ._geometry import (
crop_image_pil, crop_image_pil,
crop_image_tensor, crop_image_tensor,
crop_segmentation_mask, crop_segmentation_mask,
elastic,
elastic_bounding_box,
elastic_image_pil,
elastic_image_tensor,
elastic_segmentation_mask,
elastic_transform,
five_crop_image_pil, five_crop_image_pil,
five_crop_image_tensor, five_crop_image_tensor,
horizontal_flip, horizontal_flip,
...@@ -97,7 +103,7 @@ from ._geometry import ( ...@@ -97,7 +103,7 @@ from ._geometry import (
vertical_flip_image_tensor, vertical_flip_image_tensor,
vertical_flip_segmentation_mask, vertical_flip_segmentation_mask,
) )
from ._misc import gaussian_blur_image_tensor, normalize_image_tensor from ._misc import gaussian_blur, gaussian_blur_image_pil, gaussian_blur_image_tensor, normalize_image_tensor
from ._type_conversion import ( from ._type_conversion import (
decode_image_with_pil, decode_image_with_pil,
decode_video_with_av, decode_video_with_av,
......
import PIL.Image
import torch
from torchvision.transforms import functional_tensor as _FT
from torchvision.transforms.functional import pil_to_tensor, to_pil_image

erase_image_tensor = _FT.erase


def erase_image_pil(
    img: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> PIL.Image.Image:
    t_img = pil_to_tensor(img)
    output = erase_image_tensor(t_img, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
    return to_pil_image(output, mode=img.mode)
...@@ -9,8 +9,11 @@ from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
from torchvision.transforms.functional import (
    _compute_output_size,
    _get_inverse_affine_matrix,
    _get_perspective_coeffs,
    InterpolationMode,
    pil_modes_mapping,
    pil_to_tensor,
    to_pil_image,
)

from ._meta import convert_bounding_box_format, get_dimensions_image_pil, get_dimensions_image_tensor
...@@ -759,16 +762,21 @@ def perspective_bounding_box(
    ).view(original_shape)


def perspective_segmentation_mask(mask: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor:
    return perspective_image_tensor(
        mask, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST
    )


def perspective(
    inpt: DType,
    startpoints: List[List[int]],
    endpoints: List[List[int]],
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> DType:
    perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints)
    if isinstance(inpt, features._Feature):
        return inpt.perspective(perspective_coeffs, interpolation=interpolation, fill=fill)
    elif isinstance(inpt, PIL.Image.Image):
...@@ -779,6 +787,91 @@ def perspective(
        return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill)
def elastic_image_tensor(
img: torch.Tensor,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[List[float]] = None,
) -> torch.Tensor:
return _FT.elastic_transform(img, displacement, interpolation=interpolation.value, fill=fill)
def elastic_image_pil(
img: PIL.Image.Image,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> PIL.Image.Image:
t_img = pil_to_tensor(img)
fill = _convert_fill_arg(fill)
output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill)
return to_pil_image(output, mode=img.mode)
def elastic_bounding_box(
bounding_box: torch.Tensor,
format: features.BoundingBoxFormat,
displacement: torch.Tensor,
) -> torch.Tensor:
# TODO: document in the docstring the approximation we are doing for grid inversion
displacement = displacement.to(bounding_box.device)
original_shape = bounding_box.shape
bounding_box = convert_bounding_box_format(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)
# Question (vfdev-5): should we rely on good displacement shape and fetch image size from it
# Or add image_size arg and check displacement shape
image_size = displacement.shape[-3], displacement.shape[-2]
id_grid = _FT._create_identity_grid(list(image_size)).to(bounding_box.device)
# We construct an approximation of inverse grid as inv_grid = id_grid - displacement
# This is not an exact inverse of the grid
inv_grid = id_grid - displacement
# Get points from bboxes
points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2)
index_x = torch.floor(points[:, 0] + 0.5).to(dtype=torch.long)
index_y = torch.floor(points[:, 1] + 0.5).to(dtype=torch.long)
# Transform points:
t_size = torch.tensor(image_size[::-1], device=displacement.device, dtype=displacement.dtype)
transformed_points = (inv_grid[0, index_y, index_x, :] + 1) * 0.5 * t_size - 0.5
transformed_points = transformed_points.view(-1, 4, 2)
out_bbox_mins, _ = torch.min(transformed_points, dim=1)
out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)
return convert_bounding_box_format(
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
).view(original_shape)
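For context on the approximation noted in the comments above: the elastic grid used for sampling maps an output location p to the input location p + d(p), so mapping box corners from input to output space requires the inverse map, i.e. solving p + d(p) = q for p. The kernel uses the first-order approximation p ≈ q - d(q), which holds when the displacement field is small and smooth. A tiny self-contained sketch (hypothetical helpers, not part of this PR) illustrating the residual error of that approximation on a single point:

import torch

def forward_map(p, d_fn):
    # Grid value at p: where the output pixel at p samples from in the input.
    return p + d_fn(p)

def approx_inverse(q, d_fn):
    # First-order approximation of the inverse map, as used by elastic_bounding_box.
    return q - d_fn(q)

d_fn = lambda p: 0.05 * torch.sin(3.0 * p)  # small, smooth displacement field
q = torch.tensor([0.3, -0.2])
p = approx_inverse(q, d_fn)
residual = forward_map(p, d_fn) - q  # close to zero for small displacements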
def elastic_segmentation_mask(mask: torch.Tensor, displacement: torch.Tensor) -> torch.Tensor:
return elastic_image_tensor(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST)
def elastic(
inpt: DType,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> DType:
if isinstance(inpt, features._Feature):
return inpt.elastic(displacement, interpolation=interpolation, fill=fill)
elif isinstance(inpt, PIL.Image.Image):
return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill)
else:
fill = _convert_fill_arg(fill)
return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill)
elastic_transform = elastic
def _center_crop_parse_output_size(output_size: List[int]) -> List[int]: def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
if isinstance(output_size, numbers.Number): if isinstance(output_size, numbers.Number):
return [int(output_size), int(output_size)] return [int(output_size), int(output_size)]
......
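A minimal sketch of calling the mid-level dispatcher added in this file, assuming a plain uint8 tensor input; the displacement field uses the [1, H, W, 2] shape that the stable elastic_transform kernel now validates (see the change further down in this diff):

import torch
from torchvision.prototype.transforms import functional as F

img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
h, w = img.shape[-2:]
# Displacement is expressed in normalized grid coordinates, shape [1, H, W, 2].
displacement = 0.1 * (torch.rand(1, h, w, 2) * 2 - 1)

out = F.elastic(img, displacement=displacement)              # plain tensors go through elastic_image_tensor
alias = F.elastic_transform(img, displacement=displacement)  # elastic_transform is an alias for elastic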
from typing import List, Optional, Union

import PIL.Image
import torch
from torchvision.prototype import features
from torchvision.transforms import functional_tensor as _FT
from torchvision.transforms.functional import pil_to_tensor, to_pil_image

# shortcut type
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]

normalize_image_tensor = _FT.normalize


def normalize(inpt: DType, mean: List[float], std: List[float], inplace: bool = False) -> DType:
    if isinstance(inpt, features.Image):
        return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
    elif type(inpt) == torch.Tensor:
        return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
    else:
        raise TypeError("Unsupported input type")
def gaussian_blur_image_tensor( def gaussian_blur_image_tensor(
img: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None img: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -42,3 +56,12 @@ def gaussian_blur_image_pil(img: PIL.Image, kernel_size: List[int], sigma: Optio ...@@ -42,3 +56,12 @@ def gaussian_blur_image_pil(img: PIL.Image, kernel_size: List[int], sigma: Optio
t_img = pil_to_tensor(img) t_img = pil_to_tensor(img)
output = gaussian_blur_image_tensor(t_img, kernel_size=kernel_size, sigma=sigma) output = gaussian_blur_image_tensor(t_img, kernel_size=kernel_size, sigma=sigma)
return to_pil_image(output, mode=img.mode) return to_pil_image(output, mode=img.mode)
def gaussian_blur(inpt: DType, kernel_size: List[int], sigma: Optional[List[float]] = None) -> DType:
if isinstance(inpt, features._Feature):
return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma)
elif isinstance(inpt, PIL.Image.Image):
return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma)
else:
return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
...@@ -1554,7 +1554,7 @@ def elastic_transform( ...@@ -1554,7 +1554,7 @@ def elastic_transform(
If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
where ... means it can have an arbitrary number of leading dimensions. where ... means it can have an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "P", "L" or "RGB". If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".
displacement (Tensor): The displacement field. displacement (Tensor): The displacement field. Expected shape is [1, H, W, 2].
interpolation (InterpolationMode): Desired interpolation enum defined by interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. :class:`torchvision.transforms.InterpolationMode`.
Default is ``InterpolationMode.BILINEAR``. Default is ``InterpolationMode.BILINEAR``.
...@@ -1576,7 +1576,7 @@ def elastic_transform( ...@@ -1576,7 +1576,7 @@ def elastic_transform(
interpolation = _interpolation_modes_from_int(interpolation) interpolation = _interpolation_modes_from_int(interpolation)
if not isinstance(displacement, torch.Tensor): if not isinstance(displacement, torch.Tensor):
raise TypeError("displacement should be a Tensor") raise TypeError("Argument displacement should be a Tensor")
t_img = img t_img = img
if not isinstance(img, torch.Tensor): if not isinstance(img, torch.Tensor):
...@@ -1584,6 +1584,15 @@ def elastic_transform( ...@@ -1584,6 +1584,15 @@ def elastic_transform(
raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}") raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}")
t_img = pil_to_tensor(img) t_img = pil_to_tensor(img)
shape = t_img.shape
shape = (1,) + shape[-2:] + (2,)
if shape != displacement.shape:
raise ValueError(f"Argument displacement shape should be {shape}, but given {displacement.shape}")
# TODO: if the image shape is [N1, N2, ..., C, H, W] and
# displacement is [1, H, W, 2], we need to reshape the input image
# so that grid_sampler takes the internal code path for 4D input
output = F_t.elastic_transform( output = F_t.elastic_transform(
t_img, t_img,
displacement, displacement,
......
...@@ -260,7 +260,7 @@ def _parse_fill( ...@@ -260,7 +260,7 @@ def _parse_fill(
) -> Dict[str, Optional[Union[float, List[float], Tuple[float, ...]]]]: ) -> Dict[str, Optional[Union[float, List[float], Tuple[float, ...]]]]:
# Process fill color for affine transforms # Process fill color for affine transforms
num_bands = len(img.getbands()) num_bands = get_image_num_channels(img)
if fill is None: if fill is None:
fill = 0 fill = 0
if isinstance(fill, (int, float)) and num_bands > 1: if isinstance(fill, (int, float)) and num_bands > 1:
......