Unverified commit a5536de9, authored by vfdev, committed by GitHub
Browse files

Added antialias arg to resized crop transform and op (#6193)

parent 11caf37a
...@@ -447,10 +447,17 @@ class TestResize: ...@@ -447,10 +447,17 @@ class TestResize:
], ],
) )
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC]) @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC])
def test_resized_crop(self, scale, ratio, size, interpolation, device): @pytest.mark.parametrize("antialias", [None, True, False])
def test_resized_crop(self, scale, ratio, size, interpolation, antialias, device):
if antialias and interpolation == NEAREST:
pytest.skip("Can not resize if interpolation mode is NEAREST and antialias=True")
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
transform = T.RandomResizedCrop(size=size, scale=scale, ratio=ratio, interpolation=interpolation) transform = T.RandomResizedCrop(
size=size, scale=scale, ratio=ratio, interpolation=interpolation, antialias=antialias
)
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
_test_transform_vs_scripted(transform, s_transform, tensor) _test_transform_vs_scripted(transform, s_transform, tensor)
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
......
...@@ -555,6 +555,7 @@ def resized_crop( ...@@ -555,6 +555,7 @@ def resized_crop(
width: int, width: int,
size: List[int], size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
) -> Tensor: ) -> Tensor:
"""Crop the given image and resize it to desired size. """Crop the given image and resize it to desired size.
If the image is torch Tensor, it is expected If the image is torch Tensor, it is expected
...@@ -575,13 +576,17 @@ def resized_crop( ...@@ -575,13 +576,17 @@ def resized_crop(
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted, For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum. but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` modes.
This can help make the output for PIL images and tensors closer.
Returns: Returns:
PIL Image or Tensor: Cropped image. PIL Image or Tensor: Cropped image.
""" """
if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(resized_crop) _log_api_usage_once(resized_crop)
img = crop(img, top, left, height, width) img = crop(img, top, left, height, width)
img = resize(img, size, interpolation) img = resize(img, size, interpolation, antialias=antialias)
return img return img
......
...@@ -310,12 +310,8 @@ class Resize(torch.nn.Module): ...@@ -310,12 +310,8 @@ class Resize(torch.nn.Module):
mode). mode).
antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for
``InterpolationMode.BILINEAR`` only mode. This can help making the output for PIL images and tensors ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` modes.
closer. This can help making the output for PIL images and tensors closer.
.. warning::
There is no autodiff support for ``antialias=True`` option with input ``img`` as Tensor.
""" """
def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=None): def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=None):
...@@ -873,9 +869,20 @@ class RandomResizedCrop(torch.nn.Module): ...@@ -873,9 +869,20 @@ class RandomResizedCrop(torch.nn.Module):
``InterpolationMode.BICUBIC`` are supported. ``InterpolationMode.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted, For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum. but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` modes.
This can help make the output for PIL images and tensors closer.
""" """
def __init__(self, size, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0), interpolation=InterpolationMode.BILINEAR): def __init__(
self,
size,
scale=(0.08, 1.0),
ratio=(3.0 / 4.0, 4.0 / 3.0),
interpolation=InterpolationMode.BILINEAR,
antialias: Optional[bool] = None,
):
super().__init__() super().__init__()
_log_api_usage_once(self) _log_api_usage_once(self)
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
...@@ -896,6 +903,7 @@ class RandomResizedCrop(torch.nn.Module): ...@@ -896,6 +903,7 @@ class RandomResizedCrop(torch.nn.Module):
interpolation = _interpolation_modes_from_int(interpolation) interpolation = _interpolation_modes_from_int(interpolation)
self.interpolation = interpolation self.interpolation = interpolation
self.antialias = antialias
self.scale = scale self.scale = scale
self.ratio = ratio self.ratio = ratio
...@@ -952,7 +960,7 @@ class RandomResizedCrop(torch.nn.Module): ...@@ -952,7 +960,7 @@ class RandomResizedCrop(torch.nn.Module):
PIL Image or Tensor: Randomly cropped and resized image. PIL Image or Tensor: Randomly cropped and resized image.
""" """
i, j, h, w = self.get_params(img, self.scale, self.ratio) i, j, h, w = self.get_params(img, self.scale, self.ratio)
return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) return F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias)
def __repr__(self) -> str: def __repr__(self) -> str:
interpolate_str = self.interpolation.value interpolate_str = self.interpolation.value
...@@ -960,6 +968,7 @@ class RandomResizedCrop(torch.nn.Module): ...@@ -960,6 +968,7 @@ class RandomResizedCrop(torch.nn.Module):
format_string += f", scale={tuple(round(s, 4) for s in self.scale)}" format_string += f", scale={tuple(round(s, 4) for s in self.scale)}"
format_string += f", ratio={tuple(round(r, 4) for r in self.ratio)}" format_string += f", ratio={tuple(round(r, 4) for r in self.ratio)}"
format_string += f", interpolation={interpolate_str})" format_string += f", interpolation={interpolate_str})"
format_string += f", antialias={self.antialias})"
return format_string return format_string
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment