Unverified Commit dbbc5c8e authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Added dict support for fill arg for remaining transforms (#6599)

* Updated fill arg typehint for affine, perspective and elastic ops

* Updated pad op on prototype side

* Code updates

* Few other minor updates

* WIP

* WIP

* Updates

* Update _image.py

* Fixed tests
parent 2718f734
......@@ -391,7 +391,7 @@ class TestPad:
if isinstance(fill, int):
calls = [
mocker.call(image, padding=1, fill=fill, padding_mode="constant"),
mocker.call(mask, padding=1, fill=0, padding_mode="constant"),
mocker.call(mask, padding=1, fill=fill, padding_mode="constant"),
]
else:
calls = [
......@@ -467,7 +467,7 @@ class TestRandomZoomOut:
if isinstance(fill, int):
calls = [
mocker.call(image, **params, fill=fill),
mocker.call(mask, **params, fill=0),
mocker.call(mask, **params, fill=fill),
]
else:
calls = [
......@@ -1555,7 +1555,7 @@ class TestFixedSizeCrop:
@pytest.mark.parametrize("needs", list(itertools.product((False, True), repeat=2)))
def test__transform(self, mocker, needs):
fill_sentinel = mocker.MagicMock()
fill_sentinel = 12
padding_mode_sentinel = mocker.MagicMock()
transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel)
......
......@@ -195,9 +195,16 @@ def pad_image_tensor():
for image, padding, fill, padding_mode in itertools.product(
make_images(),
[[1], [1, 1], [1, 1, 2, 2]], # padding
[None, 12, 12.0], # fill
[None, 128.0, 128, [12.0], [12.0, 13.0, 14.0]], # fill
["constant", "symmetric", "edge", "reflect"], # padding mode,
):
if padding_mode != "constant" and fill is not None:
# ValueError: Padding mode 'reflect' is not supported if fill is not scalar
continue
if isinstance(fill, list) and len(fill) != image.shape[-3]:
continue
yield ArgsKwargs(image, padding=padding, fill=fill, padding_mode=padding_mode)
......
......@@ -211,10 +211,12 @@ def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
_check_fill_arg(fill)
if isinstance(fill, dict):
return fill
else:
return defaultdict(lambda: fill, {features.Mask: 0}) # type: ignore[arg-type, return-value]
return defaultdict(lambda: fill) # type: ignore[arg-type, return-value]
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
......@@ -242,7 +244,6 @@ class Pad(Transform):
super().__init__()
_check_padding_arg(padding)
_check_fill_arg(fill)
_check_padding_mode_arg(padding_mode)
self.padding = padding
......@@ -263,7 +264,6 @@ class RandomZoomOut(_RandomApplyTransform):
) -> None:
super().__init__(p=p)
_check_fill_arg(fill)
self.fill = _setup_fill_arg(fill)
_check_sequence_input(side_range, "side_range", req_sizes=(2,))
......@@ -299,7 +299,7 @@ class RandomRotation(Transform):
degrees: Union[numbers.Number, Sequence],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Union[FillType, Dict[Type, FillType]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
......@@ -307,9 +307,7 @@ class RandomRotation(Transform):
self.interpolation = interpolation
self.expand = expand
_check_fill_arg(fill)
self.fill = fill
self.fill = _setup_fill_arg(fill)
if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,))
......@@ -321,12 +319,13 @@ class RandomRotation(Transform):
return dict(angle=angle)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.rotate(
inpt,
**params,
interpolation=self.interpolation,
expand=self.expand,
fill=self.fill,
fill=fill,
center=self.center,
)
......@@ -339,7 +338,7 @@ class RandomAffine(Transform):
scale: Optional[Sequence[float]] = None,
shear: Optional[Union[float, Sequence[float]]] = None,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Union[FillType, Dict[Type, FillType]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
......@@ -363,10 +362,7 @@ class RandomAffine(Transform):
self.shear = shear
self.interpolation = interpolation
_check_fill_arg(fill)
self.fill = fill
self.fill = _setup_fill_arg(fill)
if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,))
......@@ -404,11 +400,12 @@ class RandomAffine(Transform):
return dict(angle=angle, translate=translate, scale=scale, shear=shear)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.affine(
inpt,
**params,
interpolation=self.interpolation,
fill=self.fill,
fill=fill,
center=self.center,
)
......@@ -419,7 +416,7 @@ class RandomCrop(Transform):
size: Union[int, Sequence[int]],
padding: Optional[Union[int, Sequence[int]]] = None,
pad_if_needed: bool = False,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Union[FillType, Dict[Type, FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
......@@ -429,12 +426,11 @@ class RandomCrop(Transform):
if pad_if_needed or padding is not None:
if padding is not None:
_check_padding_arg(padding)
_check_fill_arg(fill)
_check_padding_mode_arg(padding_mode)
self.padding = padding
self.pad_if_needed = pad_if_needed
self.fill = fill
self.fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode
def _get_params(self, sample: Any) -> Dict[str, Any]:
......@@ -483,17 +479,18 @@ class RandomCrop(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# TODO: (PERF) check for speed optimization if we avoid repeated pad calls
fill = self.fill[type(inpt)]
if self.padding is not None:
inpt = F.pad(inpt, padding=self.padding, fill=self.fill, padding_mode=self.padding_mode)
inpt = F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)
if self.pad_if_needed:
input_width, input_height = params["input_width"], params["input_height"]
if input_width < self.size[1]:
padding = [self.size[1] - input_width, 0]
inpt = F.pad(inpt, padding=padding, fill=self.fill, padding_mode=self.padding_mode)
inpt = F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
if input_height < self.size[0]:
padding = [0, self.size[0] - input_height]
inpt = F.pad(inpt, padding=padding, fill=self.fill, padding_mode=self.padding_mode)
inpt = F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)
return F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
......@@ -502,19 +499,18 @@ class RandomPerspective(_RandomApplyTransform):
def __init__(
self,
distortion_scale: float = 0.5,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Union[FillType, Dict[Type, FillType]] = 0,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
p: float = 0.5,
) -> None:
super().__init__(p=p)
_check_fill_arg(fill)
if not (0 <= distortion_scale <= 1):
raise ValueError("Argument distortion_scale value should be between 0 and 1")
self.distortion_scale = distortion_scale
self.interpolation = interpolation
self.fill = fill
self.fill = _setup_fill_arg(fill)
def _get_params(self, sample: Any) -> Dict[str, Any]:
# Get image size
......@@ -546,10 +542,11 @@ class RandomPerspective(_RandomApplyTransform):
return dict(startpoints=startpoints, endpoints=endpoints)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.perspective(
inpt,
**params,
fill=self.fill,
fill=fill,
interpolation=self.interpolation,
)
......@@ -576,17 +573,15 @@ class ElasticTransform(Transform):
self,
alpha: Union[float, Sequence[float]] = 50.0,
sigma: Union[float, Sequence[float]] = 5.0,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Union[FillType, Dict[Type, FillType]] = 0,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> None:
super().__init__()
self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
_check_fill_arg(fill)
self.interpolation = interpolation
self.fill = fill
self.fill = _setup_fill_arg(fill)
def _get_params(self, sample: Any) -> Dict[str, Any]:
# Get image size
......@@ -614,10 +609,11 @@ class ElasticTransform(Transform):
return dict(displacement=displacement)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self.fill[type(inpt)]
return F.elastic(
inpt,
**params,
fill=self.fill,
fill=fill,
interpolation=self.interpolation,
)
......@@ -789,14 +785,16 @@ class FixedSizeCrop(Transform):
def __init__(
self,
size: Union[int, Sequence[int]],
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Union[FillType, Dict[Type, FillType]] = 0,
padding_mode: str = "constant",
) -> None:
super().__init__()
size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
self.crop_height = size[0]
self.crop_width = size[1]
self.fill = fill # TODO: Fill is currently respected only on PIL. Apply tensor patch.
self.fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode
def _get_params(self, sample: Any) -> Dict[str, Any]:
......@@ -869,7 +867,8 @@ class FixedSizeCrop(Transform):
)
if params["needs_pad"]:
inpt = F.pad(inpt, params["padding"], fill=self.fill, padding_mode=self.padding_mode)
fill = self.fill[type(inpt)]
inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
return inpt
......
......@@ -499,7 +499,7 @@ def _assert_grid_transform_inputs(
# Check fill
num_channels = get_dimensions(img)[0]
if fill is not None and isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
if fill is not None and isinstance(fill, (tuple, list)) and len(fill) > 1 and len(fill) != num_channels:
msg = (
"The number of elements in 'fill' cannot broadcast to match the number of "
"channels of the image ({} != {})"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment