Unverified Commit 6908129a authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Fixed issue with `F.pad` from RandomZoomOut (#6386)

* [proto] Fixed issue with `F.pad` from RandomZoomOut

* Fixed failing tests

* Fixed wrong type hint

* Fixed fill=None in pad_image_pil

* Try to support fill=None in functional

* Code formatting
parent 2e70ee1a
......@@ -377,12 +377,11 @@ class TestRandomZoomOut:
transform = transforms.RandomZoomOut(fill=fill, side_range=side_range)
image = mocker.MagicMock(spec=features.Image)
c = image.num_channels = 3
h, w = image.image_size = (24, 32)
params = transform._get_params(image)
assert params["fill"] == (fill if not isinstance(fill, int) else [fill] * c)
assert params["fill"] == fill
assert len(params["padding"]) == 4
assert 0 <= params["padding"][0] <= (side_range[1] - 1) * w
assert 0 <= params["padding"][1] <= (side_range[1] - 1) * h
......
......@@ -464,7 +464,7 @@ def pad_image_tensor():
for image, padding, fill, padding_mode in itertools.product(
make_images(),
[[1], [1, 1], [1, 1, 2, 2]], # padding
[12, 12.0], # fill
[None, 12, 12.0], # fill
["constant", "symmetric", "edge", "reflect"], # padding mode,
):
yield SampleInput(image, padding=padding, fill=fill, padding_mode=padding_mode)
......
......@@ -174,11 +174,8 @@ class Image(_Feature):
if not isinstance(padding, int):
padding = list(padding)
if fill is None:
fill = 0
# PyTorch's pad supports only scalars on fill. So we need to overwrite the colour
if isinstance(fill, (int, float)):
if isinstance(fill, (int, float)) or fill is None:
output = _F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
else:
from torchvision.prototype.transforms.functional._geometry import _pad_with_vector_fill
......
......@@ -294,12 +294,7 @@ class RandomZoomOut(_RandomApplyTransform):
bottom = canvas_height - (top + orig_h)
padding = [left, top, right, bottom]
# vfdev-5: Can we put that into pad_image_tensor ?
fill = self.fill
if not isinstance(fill, collections.abc.Sequence):
fill = [fill] * orig_c
return dict(padding=padding, fill=fill)
return dict(padding=padding, fill=self.fill)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.pad(inpt, **params)
......
......@@ -531,7 +531,10 @@ pad_image_pil = _FP.pad
def pad_image_tensor(
img: torch.Tensor, padding: Union[int, List[int]], fill: Union[int, float] = 0, padding_mode: str = "constant"
img: torch.Tensor,
padding: Union[int, List[int]],
fill: Optional[Union[int, float]] = 0,
padding_mode: str = "constant",
) -> torch.Tensor:
num_channels, height, width = img.shape[-3:]
extra_dims = img.shape[:-3]
......@@ -555,7 +558,7 @@ def _pad_with_vector_fill(
raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")
output = pad_image_tensor(img, padding, fill=0, padding_mode="constant")
left, top, right, bottom = _FT._parse_pad_padding(padding)
left, right, top, bottom = _FT._parse_pad_padding(padding)
fill = torch.tensor(fill, dtype=img.dtype, device=img.device).view(-1, 1, 1)
if top > 0:
......@@ -614,11 +617,8 @@ def pad(
if not isinstance(padding, int):
padding = list(padding)
if fill is None:
fill = 0
# TODO: PyTorch's pad supports only scalars on fill. So we need to overwrite the colour
if isinstance(fill, (int, float)):
if isinstance(fill, (int, float)) or fill is None:
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
return _pad_with_vector_fill(inpt, padding, fill=fill, padding_mode=padding_mode)
......
......@@ -155,7 +155,7 @@ def pad(
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if not isinstance(fill, (numbers.Number, tuple, list)):
if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg")
if not isinstance(padding_mode, str):
raise TypeError("Got inappropriate padding_mode arg")
......
......@@ -371,10 +371,13 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
def pad(
img: Tensor, padding: Union[int, List[int]], fill: Union[int, float] = 0, padding_mode: str = "constant"
img: Tensor, padding: Union[int, List[int]], fill: Optional[Union[int, float]] = 0, padding_mode: str = "constant"
) -> Tensor:
_assert_image_tensor(img)
if fill is None:
fill = 0
if not isinstance(padding, (int, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if not isinstance(fill, (int, float)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment