Unverified commit bdc55567, authored by Philip Meier and committed by GitHub
Browse files

introduce nearest-exact interpolation (#6754)

* introduce nearest-exact interpolation

* update prototype tests

* update stable tests
parent 3eafe77a
...@@ -232,6 +232,7 @@ def reference_inputs_resize_image_tensor(): ...@@ -232,6 +232,7 @@ def reference_inputs_resize_image_tensor():
make_image_loaders(extra_dims=[()]), make_image_loaders(extra_dims=[()]),
[ [
F.InterpolationMode.NEAREST, F.InterpolationMode.NEAREST,
F.InterpolationMode.NEAREST_EXACT,
F.InterpolationMode.BILINEAR, F.InterpolationMode.BILINEAR,
F.InterpolationMode.BICUBIC, F.InterpolationMode.BICUBIC,
], ],
...@@ -881,6 +882,7 @@ def reference_inputs_resized_crop_image_tensor(): ...@@ -881,6 +882,7 @@ def reference_inputs_resized_crop_image_tensor():
make_image_loaders(extra_dims=[()]), make_image_loaders(extra_dims=[()]),
[ [
F.InterpolationMode.NEAREST, F.InterpolationMode.NEAREST,
F.InterpolationMode.NEAREST_EXACT,
F.InterpolationMode.BILINEAR, F.InterpolationMode.BILINEAR,
F.InterpolationMode.BICUBIC, F.InterpolationMode.BICUBIC,
], ],
......
...@@ -25,7 +25,12 @@ from common_utils import ( ...@@ -25,7 +25,12 @@ from common_utils import (
) )
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC NEAREST, NEAREST_EXACT, BILINEAR, BICUBIC = (
InterpolationMode.NEAREST,
InterpolationMode.NEAREST_EXACT,
InterpolationMode.BILINEAR,
InterpolationMode.BICUBIC,
)
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
...@@ -506,7 +511,7 @@ def test_perspective_interpolation_warning(): ...@@ -506,7 +511,7 @@ def test_perspective_interpolation_warning():
], ],
) )
@pytest.mark.parametrize("max_size", [None, 34, 40, 1000]) @pytest.mark.parametrize("max_size", [None, 34, 40, 1000])
@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST]) @pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST, NEAREST_EXACT])
def test_resize(device, dt, size, max_size, interpolation): def test_resize(device, dt, size, max_size, interpolation):
if dt == torch.float16 and device == "cpu": if dt == torch.float16 and device == "cpu":
...@@ -966,7 +971,7 @@ def test_pad(device, dt, pad, config): ...@@ -966,7 +971,7 @@ def test_pad(device, dt, pad, config):
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("mode", [NEAREST, BILINEAR, BICUBIC]) @pytest.mark.parametrize("mode", [NEAREST, NEAREST_EXACT, BILINEAR, BICUBIC])
def test_resized_crop(device, mode): def test_resized_crop(device, mode):
# test values of F.resized_crop in several cases: # test values of F.resized_crop in several cases:
# 1) resize to the same size, crop to the same size => should be identity # 1) resize to the same size, crop to the same size => should be identity
......
...@@ -20,7 +20,12 @@ from torchvision import transforms as T ...@@ -20,7 +20,12 @@ from torchvision import transforms as T
from torchvision.transforms import functional as F, InterpolationMode from torchvision.transforms import functional as F, InterpolationMode
from torchvision.transforms.autoaugment import _apply_op from torchvision.transforms.autoaugment import _apply_op
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC NEAREST, NEAREST_EXACT, BILINEAR, BICUBIC = (
InterpolationMode.NEAREST,
InterpolationMode.NEAREST_EXACT,
InterpolationMode.BILINEAR,
InterpolationMode.BICUBIC,
)
def _test_transform_vs_scripted(transform, s_transform, tensor, msg=None): def _test_transform_vs_scripted(transform, s_transform, tensor, msg=None):
...@@ -378,7 +383,7 @@ class TestResize: ...@@ -378,7 +383,7 @@ class TestResize:
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64]) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64])
@pytest.mark.parametrize("size", [[32], [32, 32], (32, 32), [34, 35]]) @pytest.mark.parametrize("size", [[32], [32, 32], (32, 32), [34, 35]])
@pytest.mark.parametrize("max_size", [None, 35, 1000]) @pytest.mark.parametrize("max_size", [None, 35, 1000])
@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST]) @pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST, NEAREST_EXACT])
def test_resize_scripted(self, dt, size, max_size, interpolation, device): def test_resize_scripted(self, dt, size, max_size, interpolation, device):
tensor, _ = _create_data(height=34, width=36, device=device) tensor, _ = _create_data(height=34, width=36, device=device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
...@@ -402,12 +407,12 @@ class TestResize: ...@@ -402,12 +407,12 @@ class TestResize:
@pytest.mark.parametrize("scale", [(0.7, 1.2), [0.7, 1.2]]) @pytest.mark.parametrize("scale", [(0.7, 1.2), [0.7, 1.2]])
@pytest.mark.parametrize("ratio", [(0.75, 1.333), [0.75, 1.333]]) @pytest.mark.parametrize("ratio", [(0.75, 1.333), [0.75, 1.333]])
@pytest.mark.parametrize("size", [(32,), [44], [32], [32, 32], (32, 32), [44, 55]]) @pytest.mark.parametrize("size", [(32,), [44], [32], [32, 32], (32, 32), [44, 55]])
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC]) @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC, NEAREST_EXACT])
@pytest.mark.parametrize("antialias", [None, True, False]) @pytest.mark.parametrize("antialias", [None, True, False])
def test_resized_crop(self, scale, ratio, size, interpolation, antialias, device): def test_resized_crop(self, scale, ratio, size, interpolation, antialias, device):
if antialias and interpolation == NEAREST: if antialias and interpolation in {NEAREST, NEAREST_EXACT}:
pytest.skip("Can not resize if interpolation mode is NEAREST and antialias=True") pytest.skip(f"Can not resize if interpolation mode is {interpolation} and antialias=True")
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
......
...@@ -20,10 +20,12 @@ from . import functional_pil as F_pil, functional_tensor as F_t ...@@ -20,10 +20,12 @@ from . import functional_pil as F_pil, functional_tensor as F_t
class InterpolationMode(Enum): class InterpolationMode(Enum):
"""Interpolation modes """Interpolation modes
Available interpolation methods are ``nearest``, ``bilinear``, ``bicubic``, ``box``, ``hamming``, and ``lanczos``. Available interpolation methods are ``nearest``, ``nearest-exact``, ``bilinear``, ``bicubic``, ``box``, ``hamming``,
and ``lanczos``.
""" """
NEAREST = "nearest" NEAREST = "nearest"
NEAREST_EXACT = "nearest-exact"
BILINEAR = "bilinear" BILINEAR = "bilinear"
BICUBIC = "bicubic" BICUBIC = "bicubic"
# For PIL compatibility # For PIL compatibility
...@@ -50,6 +52,7 @@ pil_modes_mapping = { ...@@ -50,6 +52,7 @@ pil_modes_mapping = {
InterpolationMode.NEAREST: 0, InterpolationMode.NEAREST: 0,
InterpolationMode.BILINEAR: 2, InterpolationMode.BILINEAR: 2,
InterpolationMode.BICUBIC: 3, InterpolationMode.BICUBIC: 3,
InterpolationMode.NEAREST_EXACT: 0,
InterpolationMode.BOX: 4, InterpolationMode.BOX: 4,
InterpolationMode.HAMMING: 5, InterpolationMode.HAMMING: 5,
InterpolationMode.LANCZOS: 1, InterpolationMode.LANCZOS: 1,
...@@ -416,7 +419,8 @@ def resize( ...@@ -416,7 +419,8 @@ def resize(
interpolation (InterpolationMode): Desired interpolation enum defined by interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. :class:`torchvision.transforms.InterpolationMode`.
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``, Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
supported.
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted, For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum. but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
max_size (int, optional): The maximum allowed for the longer edge of max_size (int, optional): The maximum allowed for the longer edge of
...@@ -617,7 +621,8 @@ def resized_crop( ...@@ -617,7 +621,8 @@ def resized_crop(
interpolation (InterpolationMode): Desired interpolation enum defined by interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. :class:`torchvision.transforms.InterpolationMode`.
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``, Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
supported.
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted, For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum. but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
......
...@@ -296,8 +296,8 @@ class Resize(torch.nn.Module): ...@@ -296,8 +296,8 @@ class Resize(torch.nn.Module):
In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``. In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
interpolation (InterpolationMode): Desired interpolation enum defined by interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
``InterpolationMode.BICUBIC`` are supported. ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted, For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum. but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
max_size (int, optional): The maximum allowed for the longer edge of max_size (int, optional): The maximum allowed for the longer edge of
...@@ -865,8 +865,8 @@ class RandomResizedCrop(torch.nn.Module): ...@@ -865,8 +865,8 @@ class RandomResizedCrop(torch.nn.Module):
resizing. resizing.
interpolation (InterpolationMode): Desired interpolation enum defined by interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
``InterpolationMode.BICUBIC`` are supported. ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted, For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum. but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment