Unverified Commit dbbc5c8e authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Added dict support for fill arg for remaining transforms (#6599)

* Updated fill arg typehint for affine, perspective and elastic ops

* Updated pad op on prototype side

* Code updates

* Few other minor updates

* WIP

* WIP

* Updates

* Update _image.py

* Fixed tests
parent 2718f734
...@@ -391,7 +391,7 @@ class TestPad: ...@@ -391,7 +391,7 @@ class TestPad:
if isinstance(fill, int): if isinstance(fill, int):
calls = [ calls = [
mocker.call(image, padding=1, fill=fill, padding_mode="constant"), mocker.call(image, padding=1, fill=fill, padding_mode="constant"),
mocker.call(mask, padding=1, fill=0, padding_mode="constant"), mocker.call(mask, padding=1, fill=fill, padding_mode="constant"),
] ]
else: else:
calls = [ calls = [
...@@ -467,7 +467,7 @@ class TestRandomZoomOut: ...@@ -467,7 +467,7 @@ class TestRandomZoomOut:
if isinstance(fill, int): if isinstance(fill, int):
calls = [ calls = [
mocker.call(image, **params, fill=fill), mocker.call(image, **params, fill=fill),
mocker.call(mask, **params, fill=0), mocker.call(mask, **params, fill=fill),
] ]
else: else:
calls = [ calls = [
...@@ -1555,7 +1555,7 @@ class TestFixedSizeCrop: ...@@ -1555,7 +1555,7 @@ class TestFixedSizeCrop:
@pytest.mark.parametrize("needs", list(itertools.product((False, True), repeat=2))) @pytest.mark.parametrize("needs", list(itertools.product((False, True), repeat=2)))
def test__transform(self, mocker, needs): def test__transform(self, mocker, needs):
fill_sentinel = mocker.MagicMock() fill_sentinel = 12
padding_mode_sentinel = mocker.MagicMock() padding_mode_sentinel = mocker.MagicMock()
transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel) transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel)
......
...@@ -195,9 +195,16 @@ def pad_image_tensor(): ...@@ -195,9 +195,16 @@ def pad_image_tensor():
for image, padding, fill, padding_mode in itertools.product( for image, padding, fill, padding_mode in itertools.product(
make_images(), make_images(),
[[1], [1, 1], [1, 1, 2, 2]], # padding [[1], [1, 1], [1, 1, 2, 2]], # padding
[None, 12, 12.0], # fill [None, 128.0, 128, [12.0], [12.0, 13.0, 14.0]], # fill
["constant", "symmetric", "edge", "reflect"], # padding mode, ["constant", "symmetric", "edge", "reflect"], # padding mode,
): ):
if padding_mode != "constant" and fill is not None:
# ValueError: Padding mode 'reflect' is not supported if fill is not scalar
continue
if isinstance(fill, list) and len(fill) != image.shape[-3]:
continue
yield ArgsKwargs(image, padding=padding, fill=fill, padding_mode=padding_mode) yield ArgsKwargs(image, padding=padding, fill=fill, padding_mode=padding_mode)
......
...@@ -211,10 +211,12 @@ def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None: ...@@ -211,10 +211,12 @@ def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]: def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
_check_fill_arg(fill)
if isinstance(fill, dict): if isinstance(fill, dict):
return fill return fill
else:
return defaultdict(lambda: fill, {features.Mask: 0}) # type: ignore[arg-type, return-value] return defaultdict(lambda: fill) # type: ignore[arg-type, return-value]
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
...@@ -242,7 +244,6 @@ class Pad(Transform): ...@@ -242,7 +244,6 @@ class Pad(Transform):
super().__init__() super().__init__()
_check_padding_arg(padding) _check_padding_arg(padding)
_check_fill_arg(fill)
_check_padding_mode_arg(padding_mode) _check_padding_mode_arg(padding_mode)
self.padding = padding self.padding = padding
...@@ -263,7 +264,6 @@ class RandomZoomOut(_RandomApplyTransform): ...@@ -263,7 +264,6 @@ class RandomZoomOut(_RandomApplyTransform):
) -> None: ) -> None:
super().__init__(p=p) super().__init__(p=p)
_check_fill_arg(fill)
self.fill = _setup_fill_arg(fill) self.fill = _setup_fill_arg(fill)
_check_sequence_input(side_range, "side_range", req_sizes=(2,)) _check_sequence_input(side_range, "side_range", req_sizes=(2,))
...@@ -299,7 +299,7 @@ class RandomRotation(Transform): ...@@ -299,7 +299,7 @@ class RandomRotation(Transform):
degrees: Union[numbers.Number, Sequence], degrees: Union[numbers.Number, Sequence],
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0, fill: Union[FillType, Dict[Type, FillType]] = 0,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -307,9 +307,7 @@ class RandomRotation(Transform): ...@@ -307,9 +307,7 @@ class RandomRotation(Transform):
self.interpolation = interpolation self.interpolation = interpolation
self.expand = expand self.expand = expand
_check_fill_arg(fill) self.fill = _setup_fill_arg(fill)
self.fill = fill
if center is not None: if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,)) _check_sequence_input(center, "center", req_sizes=(2,))
...@@ -321,12 +319,13 @@ class RandomRotation(Transform): ...@@ -321,12 +319,13 @@ class RandomRotation(Transform):
return dict(angle=angle) return dict(angle=angle)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.rotate( return F.rotate(
inpt, inpt,
**params, **params,
interpolation=self.interpolation, interpolation=self.interpolation,
expand=self.expand, expand=self.expand,
fill=self.fill, fill=fill,
center=self.center, center=self.center,
) )
...@@ -339,7 +338,7 @@ class RandomAffine(Transform): ...@@ -339,7 +338,7 @@ class RandomAffine(Transform):
scale: Optional[Sequence[float]] = None, scale: Optional[Sequence[float]] = None,
shear: Optional[Union[float, Sequence[float]]] = None, shear: Optional[Union[float, Sequence[float]]] = None,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0, fill: Union[FillType, Dict[Type, FillType]] = 0,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -363,10 +362,7 @@ class RandomAffine(Transform): ...@@ -363,10 +362,7 @@ class RandomAffine(Transform):
self.shear = shear self.shear = shear
self.interpolation = interpolation self.interpolation = interpolation
self.fill = _setup_fill_arg(fill)
_check_fill_arg(fill)
self.fill = fill
if center is not None: if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,)) _check_sequence_input(center, "center", req_sizes=(2,))
...@@ -404,11 +400,12 @@ class RandomAffine(Transform): ...@@ -404,11 +400,12 @@ class RandomAffine(Transform):
return dict(angle=angle, translate=translate, scale=scale, shear=shear) return dict(angle=angle, translate=translate, scale=scale, shear=shear)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.affine( return F.affine(
inpt, inpt,
**params, **params,
interpolation=self.interpolation, interpolation=self.interpolation,
fill=self.fill, fill=fill,
center=self.center, center=self.center,
) )
...@@ -419,7 +416,7 @@ class RandomCrop(Transform): ...@@ -419,7 +416,7 @@ class RandomCrop(Transform):
size: Union[int, Sequence[int]], size: Union[int, Sequence[int]],
padding: Optional[Union[int, Sequence[int]]] = None, padding: Optional[Union[int, Sequence[int]]] = None,
pad_if_needed: bool = False, pad_if_needed: bool = False,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0, fill: Union[FillType, Dict[Type, FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -429,12 +426,11 @@ class RandomCrop(Transform): ...@@ -429,12 +426,11 @@ class RandomCrop(Transform):
if pad_if_needed or padding is not None: if pad_if_needed or padding is not None:
if padding is not None: if padding is not None:
_check_padding_arg(padding) _check_padding_arg(padding)
_check_fill_arg(fill)
_check_padding_mode_arg(padding_mode) _check_padding_mode_arg(padding_mode)
self.padding = padding self.padding = padding
self.pad_if_needed = pad_if_needed self.pad_if_needed = pad_if_needed
self.fill = fill self.fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode self.padding_mode = padding_mode
def _get_params(self, sample: Any) -> Dict[str, Any]: def _get_params(self, sample: Any) -> Dict[str, Any]:
...@@ -483,17 +479,18 @@ class RandomCrop(Transform): ...@@ -483,17 +479,18 @@ class RandomCrop(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# TODO: (PERF) check for speed optimization if we avoid repeated pad calls # TODO: (PERF) check for speed optimization if we avoid repeated pad calls
fill = self.fill[type(inpt)]
if self.padding is not None: if self.padding is not None:
inpt = F.pad(inpt, padding=self.padding, fill=self.fill, padding_mode=self.padding_mode) inpt = F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)
if self.pad_if_needed: if self.pad_if_needed:
input_width, input_height = params["input_width"], params["input_height"] input_width, input_height = params["input_width"], params["input_height"]
if input_width < self.size[1]: if input_width < self.size[1]:
padding = [self.size[1] - input_width, 0] padding = [self.size[1] - input_width, 0]
inpt = F.pad(inpt, padding=padding, fill=self.fill, padding_mode=self.padding_mode) inpt = F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
if input_height < self.size[0]: if input_height < self.size[0]:
padding = [0, self.size[0] - input_height] padding = [0, self.size[0] - input_height]
inpt = F.pad(inpt, padding=padding, fill=self.fill, padding_mode=self.padding_mode) inpt = F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
return F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]) return F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
...@@ -502,19 +499,18 @@ class RandomPerspective(_RandomApplyTransform): ...@@ -502,19 +499,18 @@ class RandomPerspective(_RandomApplyTransform):
def __init__( def __init__(
self, self,
distortion_scale: float = 0.5, distortion_scale: float = 0.5,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0, fill: Union[FillType, Dict[Type, FillType]] = 0,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
p: float = 0.5, p: float = 0.5,
) -> None: ) -> None:
super().__init__(p=p) super().__init__(p=p)
_check_fill_arg(fill)
if not (0 <= distortion_scale <= 1): if not (0 <= distortion_scale <= 1):
raise ValueError("Argument distortion_scale value should be between 0 and 1") raise ValueError("Argument distortion_scale value should be between 0 and 1")
self.distortion_scale = distortion_scale self.distortion_scale = distortion_scale
self.interpolation = interpolation self.interpolation = interpolation
self.fill = fill self.fill = _setup_fill_arg(fill)
def _get_params(self, sample: Any) -> Dict[str, Any]: def _get_params(self, sample: Any) -> Dict[str, Any]:
# Get image size # Get image size
...@@ -546,10 +542,11 @@ class RandomPerspective(_RandomApplyTransform): ...@@ -546,10 +542,11 @@ class RandomPerspective(_RandomApplyTransform):
return dict(startpoints=startpoints, endpoints=endpoints) return dict(startpoints=startpoints, endpoints=endpoints)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.perspective( return F.perspective(
inpt, inpt,
**params, **params,
fill=self.fill, fill=fill,
interpolation=self.interpolation, interpolation=self.interpolation,
) )
...@@ -576,17 +573,15 @@ class ElasticTransform(Transform): ...@@ -576,17 +573,15 @@ class ElasticTransform(Transform):
self, self,
alpha: Union[float, Sequence[float]] = 50.0, alpha: Union[float, Sequence[float]] = 50.0,
sigma: Union[float, Sequence[float]] = 5.0, sigma: Union[float, Sequence[float]] = 5.0,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0, fill: Union[FillType, Dict[Type, FillType]] = 0,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> None: ) -> None:
super().__init__() super().__init__()
self.alpha = _setup_float_or_seq(alpha, "alpha", 2) self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
self.sigma = _setup_float_or_seq(sigma, "sigma", 2) self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
_check_fill_arg(fill)
self.interpolation = interpolation self.interpolation = interpolation
self.fill = fill self.fill = _setup_fill_arg(fill)
def _get_params(self, sample: Any) -> Dict[str, Any]: def _get_params(self, sample: Any) -> Dict[str, Any]:
# Get image size # Get image size
...@@ -614,10 +609,11 @@ class ElasticTransform(Transform): ...@@ -614,10 +609,11 @@ class ElasticTransform(Transform):
return dict(displacement=displacement) return dict(displacement=displacement)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.elastic( return F.elastic(
inpt, inpt,
**params, **params,
fill=self.fill, fill=fill,
interpolation=self.interpolation, interpolation=self.interpolation,
) )
...@@ -789,14 +785,16 @@ class FixedSizeCrop(Transform): ...@@ -789,14 +785,16 @@ class FixedSizeCrop(Transform):
def __init__( def __init__(
self, self,
size: Union[int, Sequence[int]], size: Union[int, Sequence[int]],
fill: Union[int, float, Sequence[int], Sequence[float]] = 0, fill: Union[FillType, Dict[Type, FillType]] = 0,
padding_mode: str = "constant", padding_mode: str = "constant",
) -> None: ) -> None:
super().__init__() super().__init__()
size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")) size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
self.crop_height = size[0] self.crop_height = size[0]
self.crop_width = size[1] self.crop_width = size[1]
self.fill = fill # TODO: Fill is currently respected only on PIL. Apply tensor patch.
self.fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode self.padding_mode = padding_mode
def _get_params(self, sample: Any) -> Dict[str, Any]: def _get_params(self, sample: Any) -> Dict[str, Any]:
...@@ -869,7 +867,8 @@ class FixedSizeCrop(Transform): ...@@ -869,7 +867,8 @@ class FixedSizeCrop(Transform):
) )
if params["needs_pad"]: if params["needs_pad"]:
inpt = F.pad(inpt, params["padding"], fill=self.fill, padding_mode=self.padding_mode) fill = self.fill[type(inpt)]
inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
return inpt return inpt
......
...@@ -499,7 +499,7 @@ def _assert_grid_transform_inputs( ...@@ -499,7 +499,7 @@ def _assert_grid_transform_inputs(
# Check fill # Check fill
num_channels = get_dimensions(img)[0] num_channels = get_dimensions(img)[0]
if fill is not None and isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels): if fill is not None and isinstance(fill, (tuple, list)) and len(fill) > 1 and len(fill) != num_channels:
msg = ( msg = (
"The number of elements in 'fill' cannot broadcast to match the number of " "The number of elements in 'fill' cannot broadcast to match the number of "
"channels of the image ({} != {})" "channels of the image ({} != {})"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment