Unverified commit e0e95380, authored by vfdev and committed by GitHub
Browse files

[proto] Argument fill can accept dict of base types (#6586)


Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent e6c0048c
...@@ -378,6 +378,28 @@ class TestPad: ...@@ -378,6 +378,28 @@ class TestPad:
fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode) fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
@pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}])
def test__transform_image_mask(self, fill, mocker):
    """Check which fill value ``Pad`` forwards to the functional ``pad`` per input type.

    With a scalar ``fill`` the image call is expected to carry that scalar and
    the mask call ``0``; with a dict ``fill`` each call is expected to carry
    the entry stored under the input's type.
    """
    pad_transform = transforms.Pad(1, fill=fill, padding_mode="constant")
    pad_mock = mocker.patch("torchvision.prototype.transforms.functional.pad")

    image = features.Image(torch.rand(3, 32, 32))
    mask = features.Mask(torch.randint(0, 5, size=(32, 32)))
    pad_transform([image, mask])

    # Resolve the fill each input should have been padded with.
    if isinstance(fill, int):
        image_fill, mask_fill = fill, 0
    else:
        image_fill, mask_fill = fill[type(image)], fill[type(mask)]

    pad_mock.assert_has_calls(
        [
            mocker.call(image, padding=1, fill=image_fill, padding_mode="constant"),
            mocker.call(mask, padding=1, fill=mask_fill, padding_mode="constant"),
        ]
    )
class TestRandomZoomOut: class TestRandomZoomOut:
def test_assertions(self): def test_assertions(self):
...@@ -400,7 +422,6 @@ class TestRandomZoomOut: ...@@ -400,7 +422,6 @@ class TestRandomZoomOut:
params = transform._get_params(image) params = transform._get_params(image)
assert params["fill"] == fill
assert len(params["padding"]) == 4 assert len(params["padding"]) == 4
assert 0 <= params["padding"][0] <= (side_range[1] - 1) * w assert 0 <= params["padding"][0] <= (side_range[1] - 1) * w
assert 0 <= params["padding"][1] <= (side_range[1] - 1) * h assert 0 <= params["padding"][1] <= (side_range[1] - 1) * h
...@@ -426,7 +447,34 @@ class TestRandomZoomOut: ...@@ -426,7 +447,34 @@ class TestRandomZoomOut:
torch.rand(1) # random apply changes random state torch.rand(1) # random apply changes random state
params = transform._get_params(inpt) params = transform._get_params(inpt)
fn.assert_called_once_with(inpt, **params) fn.assert_called_once_with(inpt, **params, fill=fill)
@pytest.mark.parametrize("fill", [12, {features.Image: 12, features.Mask: 34}])
def test__transform_image_mask(self, fill, mocker):
    """Check which fill value ``RandomZoomOut`` forwards to ``pad`` per input type.

    With a scalar ``fill`` the image call is expected to carry that scalar and
    the mask call ``0``; with a dict ``fill`` each call is expected to carry
    the entry stored under the input's type.
    """
    zoom_out = transforms.RandomZoomOut(fill=fill, p=1.0)
    pad_mock = mocker.patch("torchvision.prototype.transforms.functional.pad")

    image = features.Image(torch.rand(3, 32, 32))
    mask = features.Mask(torch.randint(0, 5, size=(32, 32)))
    sample = [image, mask]

    torch.manual_seed(12)
    zoom_out(sample)

    # Replay the same random stream to recompute the parameters the transform drew.
    torch.manual_seed(12)
    torch.rand(1)  # random apply changes random state
    expected_params = zoom_out._get_params(sample)

    # Resolve the fill each input should have been padded with.
    if isinstance(fill, int):
        image_fill, mask_fill = fill, 0
    else:
        image_fill, mask_fill = fill[type(image)], fill[type(mask)]

    pad_mock.assert_has_calls(
        [
            mocker.call(image, **expected_params, fill=image_fill),
            mocker.call(mask, **expected_params, fill=mask_fill),
        ]
    )
class TestRandomRotation: class TestRandomRotation:
......
...@@ -58,7 +58,14 @@ class Mask(_Feature): ...@@ -58,7 +58,14 @@ class Mask(_Feature):
if not isinstance(padding, int): if not isinstance(padding, int):
padding = list(padding) padding = list(padding)
output = self._F.pad_mask(self, padding, padding_mode=padding_mode) if isinstance(fill, (int, float)) or fill is None:
if fill is None:
fill = 0
output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill)
else:
# Let's raise an error for vector fill on masks
raise ValueError("Non-scalar fill value is not supported")
return Mask.new_like(self, output) return Mask.new_like(self, output)
def rotate( def rotate(
......
import math import math
import numbers import numbers
import warnings import warnings
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Union from collections import defaultdict
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Type, Union
import PIL.Image import PIL.Image
import torch import torch
...@@ -16,6 +17,7 @@ from ._utils import _check_sequence_input, _setup_angle, _setup_size, has_all, h ...@@ -16,6 +17,7 @@ from ._utils import _check_sequence_input, _setup_angle, _setup_size, has_all, h
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature] DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
FillType = Union[int, float, Sequence[int], Sequence[float]]
class RandomHorizontalFlip(_RandomApplyTransform): class RandomHorizontalFlip(_RandomApplyTransform):
...@@ -196,11 +198,23 @@ class TenCrop(Transform): ...@@ -196,11 +198,23 @@ class TenCrop(Transform):
return super().forward(*inputs) return super().forward(*inputs)
def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
    """Validate a ``fill`` argument.

    ``fill`` is either a single fill value (a number, tuple or list) or a dict
    whose values are each validated by this same function. Dict keys are not
    inspected.

    Raises:
        TypeError: if a fill value is not a number, tuple or list.
    """
    if isinstance(fill, dict):
        # Only the values are validated, so iterate them directly.
        for value in fill.values():
            _check_fill_arg(value)
        return

    if not isinstance(fill, (numbers.Number, tuple, list)):
        raise TypeError("Got inappropriate fill arg")
def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
    """Normalize ``fill`` into a mapping indexed by input type.

    A dict is returned unchanged. Any other value becomes the default of a
    ``defaultdict`` that is pre-populated with ``0`` for ``features.Mask``.
    """
    if not isinstance(fill, dict):
        return defaultdict(lambda: fill, {features.Mask: 0})  # type: ignore[arg-type, return-value]
    return fill
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
if not isinstance(padding, (numbers.Number, tuple, list)): if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg") raise TypeError("Got inappropriate padding arg")
...@@ -220,7 +234,7 @@ class Pad(Transform): ...@@ -220,7 +234,7 @@ class Pad(Transform):
def __init__( def __init__(
self, self,
padding: Union[int, Sequence[int]], padding: Union[int, Sequence[int]],
fill: Union[int, float, Sequence[int], Sequence[float]] = 0, fill: Union[FillType, Dict[Type, FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -230,24 +244,25 @@ class Pad(Transform): ...@@ -230,24 +244,25 @@ class Pad(Transform):
_check_padding_mode_arg(padding_mode) _check_padding_mode_arg(padding_mode)
self.padding = padding self.padding = padding
self.fill = fill self.fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode self.padding_mode = padding_mode
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
    """Pad ``inpt`` using the fill value stored under its exact type."""
    return F.pad(
        inpt,
        padding=self.padding,
        fill=self.fill[type(inpt)],
        padding_mode=self.padding_mode,
    )
class RandomZoomOut(_RandomApplyTransform): class RandomZoomOut(_RandomApplyTransform):
def __init__( def __init__(
self, self,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0, fill: Union[FillType, Dict[Type, FillType]] = 0,
side_range: Sequence[float] = (1.0, 4.0), side_range: Sequence[float] = (1.0, 4.0),
p: float = 0.5, p: float = 0.5,
) -> None: ) -> None:
super().__init__(p=p) super().__init__(p=p)
_check_fill_arg(fill) _check_fill_arg(fill)
self.fill = fill self.fill = _setup_fill_arg(fill)
_check_sequence_input(side_range, "side_range", req_sizes=(2,)) _check_sequence_input(side_range, "side_range", req_sizes=(2,))
...@@ -256,7 +271,7 @@ class RandomZoomOut(_RandomApplyTransform): ...@@ -256,7 +271,7 @@ class RandomZoomOut(_RandomApplyTransform):
raise ValueError(f"Invalid canvas side range provided {side_range}.") raise ValueError(f"Invalid canvas side range provided {side_range}.")
def _get_params(self, sample: Any) -> Dict[str, Any]: def _get_params(self, sample: Any) -> Dict[str, Any]:
orig_c, orig_h, orig_w = query_chw(sample) _, orig_h, orig_w = query_chw(sample)
r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0]) r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r) canvas_width = int(orig_w * r)
...@@ -269,10 +284,11 @@ class RandomZoomOut(_RandomApplyTransform): ...@@ -269,10 +284,11 @@ class RandomZoomOut(_RandomApplyTransform):
bottom = canvas_height - (top + orig_h) bottom = canvas_height - (top + orig_h)
padding = [left, top, right, bottom] padding = [left, top, right, bottom]
return dict(padding=padding, fill=self.fill) return dict(padding=padding)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
    """Pad ``inpt`` with ``params`` plus the fill value stored under its exact type."""
    return F.pad(inpt, fill=self.fill[type(inpt)], **params)
class RandomRotation(Transform): class RandomRotation(Transform):
......
...@@ -635,14 +635,19 @@ def _pad_with_vector_fill( ...@@ -635,14 +635,19 @@ def _pad_with_vector_fill(
return output return output
def pad_mask(mask: torch.Tensor, padding: Union[int, List[int]], padding_mode: str = "constant") -> torch.Tensor: def pad_mask(
mask: torch.Tensor,
padding: Union[int, List[int]],
padding_mode: str = "constant",
fill: Optional[Union[int, float]] = 0,
) -> torch.Tensor:
if mask.ndim < 3: if mask.ndim < 3:
mask = mask.unsqueeze(0) mask = mask.unsqueeze(0)
needs_squeeze = True needs_squeeze = True
else: else:
needs_squeeze = False needs_squeeze = False
output = pad_image_tensor(img=mask, padding=padding, fill=0, padding_mode=padding_mode) output = pad_image_tensor(img=mask, padding=padding, fill=fill, padding_mode=padding_mode)
if needs_squeeze: if needs_squeeze:
output = output.squeeze(0) output = output.squeeze(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment