Unverified commit 053e7ebd, authored by Philip Meier, committed by GitHub
Browse files

port Pad to prototype transforms (#5621)

* port Pad to prototype transforms

* use literal
parent 00c119c8
...@@ -71,6 +71,7 @@ class TestSmoke: ...@@ -71,6 +71,7 @@ class TestSmoke:
transforms.CenterCrop([16, 16]), transforms.CenterCrop([16, 16]),
transforms.ConvertImageDtype(), transforms.ConvertImageDtype(),
transforms.RandomHorizontalFlip(), transforms.RandomHorizontalFlip(),
transforms.Pad(5),
) )
def test_common(self, transform, input): def test_common(self, transform, input):
transform(input) transform(input)
......
...@@ -15,6 +15,7 @@ from ._geometry import ( ...@@ -15,6 +15,7 @@ from ._geometry import (
TenCrop, TenCrop,
BatchMultiCrop, BatchMultiCrop,
RandomHorizontalFlip, RandomHorizontalFlip,
Pad,
RandomZoomOut, RandomZoomOut,
) )
from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
......
import collections.abc import collections.abc
import math import math
import numbers
import warnings import warnings
from typing import Any, Dict, List, Union, Sequence, Tuple, cast from typing import Any, Dict, List, Union, Sequence, Tuple, cast
...@@ -9,6 +10,7 @@ from torchvision.prototype import features ...@@ -9,6 +10,7 @@ from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
from torchvision.transforms.functional import pil_to_tensor from torchvision.transforms.functional import pil_to_tensor
from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
from typing_extensions import Literal
from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor
...@@ -272,42 +274,31 @@ class BatchMultiCrop(Transform): ...@@ -272,42 +274,31 @@ class BatchMultiCrop(Transform):
return apply_recursively(inputs if len(inputs) > 1 else inputs[0]) return apply_recursively(inputs if len(inputs) > 1 else inputs[0])
class RandomZoomOut(Transform): class Pad(Transform):
def __init__( def __init__(
self, fill: Union[float, Sequence[float]] = 0.0, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5 self,
padding: Union[int, Sequence[int]],
fill: Union[float, Sequence[float]] = 0.0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None: ) -> None:
super().__init__() super().__init__()
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if fill is None: if not isinstance(fill, (numbers.Number, str, tuple, list)):
fill = 0.0 raise TypeError("Got inappropriate fill arg")
self.fill = fill
self.side_range = side_range
if side_range[0] < 1.0 or side_range[0] > side_range[1]:
raise ValueError(f"Invalid canvas side range provided {side_range}.")
self.p = p
def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
orig_c, orig_h, orig_w = get_image_dimensions(image)
r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r)
canvas_height = int(orig_h * r)
r = torch.rand(2) if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
left = int((canvas_width - orig_w) * r[0]) raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
top = int((canvas_height - orig_h) * r[1])
right = canvas_width - (left + orig_w)
bottom = canvas_height - (top + orig_h)
padding = [left, top, right, bottom]
fill = self.fill if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]:
if not isinstance(fill, collections.abc.Sequence): raise ValueError(
fill = [fill] * orig_c f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
)
return dict(padding=padding, fill=fill) self.padding = padding
self.fill = fill
self.padding_mode = padding_mode
def _transform(self, input: Any, params: Dict[str, Any]) -> Any: def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, features.Image) or is_simple_tensor(input): if isinstance(input, features.Image) or is_simple_tensor(input):
...@@ -349,6 +340,48 @@ class RandomZoomOut(Transform): ...@@ -349,6 +340,48 @@ class RandomZoomOut(Transform):
else: else:
return input return input
class RandomZoomOut(Transform):
def __init__(
    self,
    fill: Union[float, Sequence[float]] = 0.0,
    side_range: Tuple[float, float] = (1.0, 4.0),
    p: float = 0.5,
) -> None:
    """Store the zoom-out configuration.

    Args:
        fill: Fill value for the padded area, either a single number or a
            sequence of numbers. ``None`` is replaced by ``0.0``.
        side_range: ``(lower, upper)`` bounds from which the canvas scale
            factor is drawn. ``lower`` must be at least ``1.0`` and must
            not exceed ``upper``.
        p: Stored as ``self.p``; compared against a random draw in
            ``forward``.

    Raises:
        ValueError: If ``side_range`` violates the bounds above.
    """
    super().__init__()

    self.fill = 0.0 if fill is None else fill
    self.side_range = side_range

    # The upper bound is only read when the lower bound is at least 1.0,
    # so the check short-circuits on an invalid lower bound.
    lower = side_range[0]
    if lower < 1.0 or lower > side_range[1]:
        raise ValueError(f"Invalid canvas side range provided {side_range}.")

    self.p = p
def _get_params(self, sample: Any) -> Dict[str, Any]:
    """Sample the padding amounts and fill value for one zoom-out.

    A scale factor is drawn from ``side_range`` to size the enlarged
    canvas, then a random offset places the original image inside it.

    Args:
        sample: The input sample; an image is looked up in it via
            ``query_image``.

    Returns:
        A dict with ``padding`` as ``[left, top, right, bottom]`` and
        ``fill`` as a sequence (a scalar fill is repeated once per channel).
    """
    num_channels, height, width = get_image_dimensions(query_image(sample))

    # First random draw: the canvas scale factor within side_range.
    lower, upper = self.side_range[0], self.side_range[1]
    scale = lower + torch.rand(1) * (upper - lower)
    canvas_width = int(width * scale)
    canvas_height = int(height * scale)

    # Second random draw: relative position of the image on the canvas.
    offsets = torch.rand(2)
    left = int((canvas_width - width) * offsets[0])
    top = int((canvas_height - height) * offsets[1])
    # Whatever space remains after the image goes on the opposite side.
    right = canvas_width - (left + width)
    bottom = canvas_height - (top + height)

    if isinstance(self.fill, collections.abc.Sequence):
        fill = self.fill
    else:
        fill = [self.fill] * num_channels

    return {"padding": [left, top, right, bottom], "fill": fill}
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
    """Pad ``input`` using the sampled parameters.

    Builds a ``Pad`` transform from ``params`` with ``padding_mode`` fixed
    to ``"constant"`` and returns the result of applying it to ``input``.
    """
    return Pad(**params, padding_mode="constant")(input)
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0] sample = inputs if len(inputs) > 1 else inputs[0]
if torch.rand(1) >= self.p: if torch.rand(1) >= self.p:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment