Unverified Commit 053e7ebd authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

port Pad to prototype transforms (#5621)

* port Pad to prototype transforms

* use literal
parent 00c119c8
......@@ -71,6 +71,7 @@ class TestSmoke:
transforms.CenterCrop([16, 16]),
transforms.ConvertImageDtype(),
transforms.RandomHorizontalFlip(),
transforms.Pad(5),
)
def test_common(self, transform, input):
transform(input)
......
......@@ -15,6 +15,7 @@ from ._geometry import (
TenCrop,
BatchMultiCrop,
RandomHorizontalFlip,
Pad,
RandomZoomOut,
)
from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
......
import collections.abc
import math
import numbers
import warnings
from typing import Any, Dict, List, Union, Sequence, Tuple, cast
......@@ -9,6 +10,7 @@ from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
from torchvision.transforms.functional import pil_to_tensor
from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
from typing_extensions import Literal
from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor
......@@ -272,42 +274,31 @@ class BatchMultiCrop(Transform):
return apply_recursively(inputs if len(inputs) > 1 else inputs[0])
class RandomZoomOut(Transform):
class Pad(Transform):
def __init__(
self, fill: Union[float, Sequence[float]] = 0.0, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
self,
padding: Union[int, Sequence[int]],
fill: Union[float, Sequence[float]] = 0.0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if fill is None:
fill = 0.0
self.fill = fill
self.side_range = side_range
if side_range[0] < 1.0 or side_range[0] > side_range[1]:
raise ValueError(f"Invalid canvas side range provided {side_range}.")
self.p = p
def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
orig_c, orig_h, orig_w = get_image_dimensions(image)
r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r)
canvas_height = int(orig_h * r)
if not isinstance(fill, (numbers.Number, str, tuple, list)):
raise TypeError("Got inappropriate fill arg")
r = torch.rand(2)
left = int((canvas_width - orig_w) * r[0])
top = int((canvas_height - orig_h) * r[1])
right = canvas_width - (left + orig_w)
bottom = canvas_height - (top + orig_h)
padding = [left, top, right, bottom]
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
fill = self.fill
if not isinstance(fill, collections.abc.Sequence):
fill = [fill] * orig_c
if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]:
raise ValueError(
f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
)
return dict(padding=padding, fill=fill)
self.padding = padding
self.fill = fill
self.padding_mode = padding_mode
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, features.Image) or is_simple_tensor(input):
......@@ -349,6 +340,48 @@ class RandomZoomOut(Transform):
else:
return input
class RandomZoomOut(Transform):
def __init__(
self, fill: Union[float, Sequence[float]] = 0.0, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
) -> None:
super().__init__()
if fill is None:
fill = 0.0
self.fill = fill
self.side_range = side_range
if side_range[0] < 1.0 or side_range[0] > side_range[1]:
raise ValueError(f"Invalid canvas side range provided {side_range}.")
self.p = p
def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
orig_c, orig_h, orig_w = get_image_dimensions(image)
r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r)
canvas_height = int(orig_h * r)
r = torch.rand(2)
left = int((canvas_width - orig_w) * r[0])
top = int((canvas_height - orig_h) * r[1])
right = canvas_width - (left + orig_w)
bottom = canvas_height - (top + orig_h)
padding = [left, top, right, bottom]
fill = self.fill
if not isinstance(fill, collections.abc.Sequence):
fill = [fill] * orig_c
return dict(padding=padding, fill=fill)
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
transform = Pad(**params, padding_mode="constant")
return transform(input)
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if torch.rand(1) >= self.p:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment