Commit 2718f734 (unverified), authored by vfdev, committed by GitHub
Browse files

Updated fill arg typehint for affine, perspective and elastic ops (#6595)

* Updated fill arg typehint for affine, perspective and elastic ops

* Updated pad op on prototype side

* Code updates

* Few other minor updates
parent 9c660c65
...@@ -226,7 +226,7 @@ def sample_inputs_affine_image_tensor(): ...@@ -226,7 +226,7 @@ def sample_inputs_affine_image_tensor():
], ],
[None, (0, 0)], [None, (0, 0)],
): ):
for fill in [None, [0.5] * image_loader.num_channels]: for fill in [None, 128.0, 128, [12.0], [0.5] * image_loader.num_channels]:
yield ArgsKwargs( yield ArgsKwargs(
image_loader, image_loader,
interpolation=interpolation_mode, interpolation=interpolation_mode,
......
...@@ -228,8 +228,12 @@ def perspective_image_tensor(): ...@@ -228,8 +228,12 @@ def perspective_image_tensor():
[1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018], [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
[0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063], [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
], ],
[None, [128], [12.0]], # fill [None, 128.0, 128, [12.0], [1.0, 2.0, 3.0]], # fill
): ):
if isinstance(fill, list) and len(fill) == 3 and image.shape[1] != 3:
# skip the test with non-broadcastable fill value
continue
yield ArgsKwargs(image, perspective_coeffs=perspective_coeffs, fill=fill) yield ArgsKwargs(image, perspective_coeffs=perspective_coeffs, fill=fill)
...@@ -268,8 +272,12 @@ def perspective_mask(): ...@@ -268,8 +272,12 @@ def perspective_mask():
def elastic_image_tensor(): def elastic_image_tensor():
for image, fill in itertools.product( for image, fill in itertools.product(
make_images(extra_dims=((), (4,))), make_images(extra_dims=((), (4,))),
[None, [128], [12.0]], # fill [None, 128.0, 128, [12.0], [1.0, 2.0, 3.0]], # fill
): ):
if isinstance(fill, list) and len(fill) == 3 and image.shape[1] != 3:
# skip the test with non-broadcastable fill value
continue
h, w = image.shape[-2:] h, w = image.shape[-2:]
displacement = torch.rand(1, h, w, 2) displacement = torch.rand(1, h, w, 2)
yield ArgsKwargs(image, displacement=displacement, fill=fill) yield ArgsKwargs(image, displacement=displacement, fill=fill)
......
...@@ -177,12 +177,9 @@ class Image(_Feature): ...@@ -177,12 +177,9 @@ class Image(_Feature):
if not isinstance(padding, int): if not isinstance(padding, int):
padding = list(padding) padding = list(padding)
# PyTorch's pad supports only scalars on fill. So we need to overwrite the colour fill = self._F._geometry._convert_fill_arg(fill)
if isinstance(fill, (int, float)) or fill is None:
output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
else:
output = self._F._geometry._pad_with_vector_fill(self, padding, fill=fill, padding_mode=padding_mode)
output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
return Image.new_like(self, output) return Image.new_like(self, output)
def rotate( def rotate(
......
...@@ -58,14 +58,9 @@ class Mask(_Feature): ...@@ -58,14 +58,9 @@ class Mask(_Feature):
if not isinstance(padding, int): if not isinstance(padding, int):
padding = list(padding) padding = list(padding)
if isinstance(fill, (int, float)) or fill is None: fill = self._F._geometry._convert_fill_arg(fill)
if fill is None:
fill = 0
output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill)
else:
# Let's raise an error for vector fill on masks
raise ValueError("Non-scalar fill value is not supported")
output = self._F.pad_mask(self, padding, padding_mode=padding_mode, fill=fill)
return Mask.new_like(self, output) return Mask.new_like(self, output)
def rotate( def rotate(
......
...@@ -232,7 +232,7 @@ def affine_image_tensor( ...@@ -232,7 +232,7 @@ def affine_image_tensor(
scale: float, scale: float,
shear: List[float], shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None, fill: Optional[Union[int, float, List[float]]] = None,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if img.numel() == 0: if img.numel() == 0:
...@@ -405,7 +405,9 @@ def affine_mask( ...@@ -405,7 +405,9 @@ def affine_mask(
return output return output
def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]) -> Optional[List[float]]: def _convert_fill_arg(
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]
) -> Optional[Union[int, float, List[float]]]:
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517 # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# So, we can't reassign fill to 0 # So, we can't reassign fill to 0
# if fill is None: # if fill is None:
...@@ -416,9 +418,6 @@ def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[f ...@@ -416,9 +418,6 @@ def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[f
# This cast does Sequence -> List[float] to please mypy and torch.jit.script # This cast does Sequence -> List[float] to please mypy and torch.jit.script
if not isinstance(fill, (int, float)): if not isinstance(fill, (int, float)):
fill = [float(v) for v in list(fill)] fill = [float(v) for v in list(fill)]
else:
# It is OK to cast int to float as later we use inpt.dtype
fill = [float(fill)]
return fill return fill
...@@ -591,7 +590,23 @@ pad_image_pil = _FP.pad ...@@ -591,7 +590,23 @@ pad_image_pil = _FP.pad
def pad_image_tensor( def pad_image_tensor(
img: torch.Tensor, img: torch.Tensor,
padding: Union[int, List[int]], padding: Union[int, List[int]],
fill: Optional[Union[int, float]] = 0, fill: Optional[Union[int, float, List[float]]] = None,
padding_mode: str = "constant",
) -> torch.Tensor:
if fill is None:
# This is a JIT workaround
return _pad_with_scalar_fill(img, padding, fill=None, padding_mode=padding_mode)
elif isinstance(fill, (int, float)) or len(fill) == 1:
fill_number = fill[0] if isinstance(fill, list) else fill
return _pad_with_scalar_fill(img, padding, fill=fill_number, padding_mode=padding_mode)
else:
return _pad_with_vector_fill(img, padding, fill=fill, padding_mode=padding_mode)
def _pad_with_scalar_fill(
img: torch.Tensor,
padding: Union[int, List[int]],
fill: Union[int, float, None],
padding_mode: str = "constant", padding_mode: str = "constant",
) -> torch.Tensor: ) -> torch.Tensor:
num_channels, height, width = img.shape[-3:] num_channels, height, width = img.shape[-3:]
...@@ -614,13 +629,13 @@ def pad_image_tensor( ...@@ -614,13 +629,13 @@ def pad_image_tensor(
def _pad_with_vector_fill( def _pad_with_vector_fill(
img: torch.Tensor, img: torch.Tensor,
padding: Union[int, List[int]], padding: Union[int, List[int]],
fill: Sequence[float] = [0.0], fill: List[float],
padding_mode: str = "constant", padding_mode: str = "constant",
) -> torch.Tensor: ) -> torch.Tensor:
if padding_mode != "constant": if padding_mode != "constant":
raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar") raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")
output = pad_image_tensor(img, padding, fill=0, padding_mode="constant") output = _pad_with_scalar_fill(img, padding, fill=0, padding_mode="constant")
left, right, top, bottom = _parse_pad_padding(padding) left, right, top, bottom = _parse_pad_padding(padding)
fill = torch.tensor(fill, dtype=img.dtype, device=img.device).view(-1, 1, 1) fill = torch.tensor(fill, dtype=img.dtype, device=img.device).view(-1, 1, 1)
...@@ -639,8 +654,14 @@ def pad_mask( ...@@ -639,8 +654,14 @@ def pad_mask(
mask: torch.Tensor, mask: torch.Tensor,
padding: Union[int, List[int]], padding: Union[int, List[int]],
padding_mode: str = "constant", padding_mode: str = "constant",
fill: Optional[Union[int, float]] = 0, fill: Optional[Union[int, float, List[float]]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if fill is None:
fill = 0
if isinstance(fill, list):
raise ValueError("Non-scalar fill value is not supported")
if mask.ndim < 3: if mask.ndim < 3:
mask = mask.unsqueeze(0) mask = mask.unsqueeze(0)
needs_squeeze = True needs_squeeze = True
...@@ -693,10 +714,9 @@ def pad( ...@@ -693,10 +714,9 @@ def pad(
if not isinstance(padding, int): if not isinstance(padding, int):
padding = list(padding) padding = list(padding)
# TODO: PyTorch's pad supports only scalars on fill. So we need to overwrite the colour fill = _convert_fill_arg(fill)
if isinstance(fill, (int, float)) or fill is None:
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode) return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
return _pad_with_vector_fill(inpt, padding, fill=fill, padding_mode=padding_mode)
crop_image_tensor = _FT.crop crop_image_tensor = _FT.crop
...@@ -739,7 +759,7 @@ def perspective_image_tensor( ...@@ -739,7 +759,7 @@ def perspective_image_tensor(
img: torch.Tensor, img: torch.Tensor,
perspective_coeffs: List[float], perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[List[float]] = None, fill: Optional[Union[int, float, List[float]]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return _FT.perspective(img, perspective_coeffs, interpolation=interpolation.value, fill=fill) return _FT.perspective(img, perspective_coeffs, interpolation=interpolation.value, fill=fill)
...@@ -878,7 +898,7 @@ def elastic_image_tensor( ...@@ -878,7 +898,7 @@ def elastic_image_tensor(
img: torch.Tensor, img: torch.Tensor,
displacement: torch.Tensor, displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[List[float]] = None, fill: Optional[Union[int, float, List[float]]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return _FT.elastic_transform(img, displacement, interpolation=interpolation.value, fill=fill) return _FT.elastic_transform(img, displacement, interpolation=interpolation.value, fill=fill)
......
...@@ -600,7 +600,10 @@ def _gen_affine_grid( ...@@ -600,7 +600,10 @@ def _gen_affine_grid(
def affine( def affine(
img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[List[float]] = None img: Tensor,
matrix: List[float],
interpolation: str = "nearest",
fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor: ) -> Tensor:
_assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"])
...@@ -693,7 +696,10 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, ...@@ -693,7 +696,10 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
def perspective( def perspective(
img: Tensor, perspective_coeffs: List[float], interpolation: str = "bilinear", fill: Optional[List[float]] = None img: Tensor,
perspective_coeffs: List[float],
interpolation: str = "bilinear",
fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor: ) -> Tensor:
if not (isinstance(img, torch.Tensor)): if not (isinstance(img, torch.Tensor)):
raise TypeError("Input img should be Tensor.") raise TypeError("Input img should be Tensor.")
...@@ -950,7 +956,7 @@ def elastic_transform( ...@@ -950,7 +956,7 @@ def elastic_transform(
img: Tensor, img: Tensor,
displacement: Tensor, displacement: Tensor,
interpolation: str = "bilinear", interpolation: str = "bilinear",
fill: Optional[List[float]] = None, fill: Optional[Union[int, float, List[float]]] = None,
) -> Tensor: ) -> Tensor:
if not (isinstance(img, torch.Tensor)): if not (isinstance(img, torch.Tensor)):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment