Unverified Commit e0e95380 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Argument fill can accept dict of base types (#6586)


Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent e6c0048c
......@@ -378,6 +378,28 @@ class TestPad:
fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
@pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}])
def test__transform_image_mask(self, fill, mocker):
transform = transforms.Pad(1, fill=fill, padding_mode="constant")
fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
image = features.Image(torch.rand(3, 32, 32))
mask = features.Mask(torch.randint(0, 5, size=(32, 32)))
inpt = [image, mask]
_ = transform(inpt)
if isinstance(fill, int):
calls = [
mocker.call(image, padding=1, fill=fill, padding_mode="constant"),
mocker.call(mask, padding=1, fill=0, padding_mode="constant"),
]
else:
calls = [
mocker.call(image, padding=1, fill=fill[type(image)], padding_mode="constant"),
mocker.call(mask, padding=1, fill=fill[type(mask)], padding_mode="constant"),
]
fn.assert_has_calls(calls)
class TestRandomZoomOut:
def test_assertions(self):
......@@ -400,7 +422,6 @@ class TestRandomZoomOut:
params = transform._get_params(image)
assert params["fill"] == fill
assert len(params["padding"]) == 4
assert 0 <= params["padding"][0] <= (side_range[1] - 1) * w
assert 0 <= params["padding"][1] <= (side_range[1] - 1) * h
......@@ -426,7 +447,34 @@ class TestRandomZoomOut:
torch.rand(1) # random apply changes random state
params = transform._get_params(inpt)
fn.assert_called_once_with(inpt, **params)
fn.assert_called_once_with(inpt, **params, fill=fill)
@pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}])
def test__transform_image_mask(self, fill, mocker):
transform = transforms.RandomZoomOut(fill=fill, p=1.0)
fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
image = features.Image(torch.rand(3, 32, 32))
mask = features.Mask(torch.randint(0, 5, size=(32, 32)))
inpt = [image, mask]
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
torch.rand(1) # random apply changes random state
params = transform._get_params(inpt)
if isinstance(fill, int):
calls = [
mocker.call(image, **params, fill=fill),
mocker.call(mask, **params, fill=0),
]
else:
calls = [
mocker.call(image, **params, fill=fill[type(image)]),
mocker.call(mask, **params, fill=fill[type(mask)]),
]
fn.assert_has_calls(calls)
class TestRandomRotation:
......
......@@ -58,7 +58,14 @@ class Mask(_Feature):
if not isinstance(padding, int):
padding = list(padding)
output = self._F.pad_mask(self, padding, padding_mode=padding_mode)
if isinstance(fill, (int, float)) or fill is None:
if fill is None:
fill = 0
output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill)
else:
# Let's raise an error for vector fill on masks
raise ValueError("Non-scalar fill value is not supported")
return Mask.new_like(self, output)
def rotate(
......
import math
import numbers
import warnings
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Union
from collections import defaultdict
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Type, Union
import PIL.Image
import torch
......@@ -16,6 +17,7 @@ from ._utils import _check_sequence_input, _setup_angle, _setup_size, has_all, h
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
FillType = Union[int, float, Sequence[int], Sequence[float]]
class RandomHorizontalFlip(_RandomApplyTransform):
......@@ -196,11 +198,23 @@ class TenCrop(Transform):
return super().forward(*inputs)
def _check_fill_arg(fill: Union[int, float, Sequence[int], Sequence[float]]) -> None:
def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
if isinstance(fill, dict):
for key, value in fill.items():
# Check key for type
_check_fill_arg(value)
else:
if not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg")
def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
if isinstance(fill, dict):
return fill
else:
return defaultdict(lambda: fill, {features.Mask: 0}) # type: ignore[arg-type, return-value]
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
......@@ -220,7 +234,7 @@ class Pad(Transform):
def __init__(
self,
padding: Union[int, Sequence[int]],
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Union[FillType, Dict[Type, FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
......@@ -230,24 +244,25 @@ class Pad(Transform):
_check_padding_mode_arg(padding_mode)
self.padding = padding
self.fill = fill
self.fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.pad(inpt, padding=self.padding, fill=self.fill, padding_mode=self.padding_mode)
fill = self.fill[type(inpt)]
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)
class RandomZoomOut(_RandomApplyTransform):
def __init__(
self,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Union[FillType, Dict[Type, FillType]] = 0,
side_range: Sequence[float] = (1.0, 4.0),
p: float = 0.5,
) -> None:
super().__init__(p=p)
_check_fill_arg(fill)
self.fill = fill
self.fill = _setup_fill_arg(fill)
_check_sequence_input(side_range, "side_range", req_sizes=(2,))
......@@ -256,7 +271,7 @@ class RandomZoomOut(_RandomApplyTransform):
raise ValueError(f"Invalid canvas side range provided {side_range}.")
def _get_params(self, sample: Any) -> Dict[str, Any]:
orig_c, orig_h, orig_w = query_chw(sample)
_, orig_h, orig_w = query_chw(sample)
r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r)
......@@ -269,10 +284,11 @@ class RandomZoomOut(_RandomApplyTransform):
bottom = canvas_height - (top + orig_h)
padding = [left, top, right, bottom]
return dict(padding=padding, fill=self.fill)
return dict(padding=padding)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.pad(inpt, **params)
fill = self.fill[type(inpt)]
return F.pad(inpt, **params, fill=fill)
class RandomRotation(Transform):
......
......@@ -635,14 +635,19 @@ def _pad_with_vector_fill(
return output
def pad_mask(mask: torch.Tensor, padding: Union[int, List[int]], padding_mode: str = "constant") -> torch.Tensor:
def pad_mask(
mask: torch.Tensor,
padding: Union[int, List[int]],
padding_mode: str = "constant",
fill: Optional[Union[int, float]] = 0,
) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = pad_image_tensor(img=mask, padding=padding, fill=0, padding_mode=padding_mode)
output = pad_image_tensor(img=mask, padding=padding, fill=fill, padding_mode=padding_mode)
if needs_squeeze:
output = output.squeeze(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment