Unverified commit 21deb4d0, authored by Zhengyang Feng, committed by GitHub
Browse files

Fill color support for tensor affine transforms (#2904)



* Fill color support for tensor affine transforms

* PEP fix

* Docstring changes and float support

* Docstring update for transforms and float type cast

* Cast only for Tensor

* Temporary patch for lack of Union type support, plus an extra unit test

* More plausible bilinear filling for tensors

* Keep things simple & New docstrings

* Fix lint and other issues after merge

* make it in one line

* Docstring and some code modifications

* More tests and corresponding changes for transforms and docstring changes

* Simplify test configs

* Update test_functional_tensor.py

* Update test_functional_tensor.py

* Move assertions
Co-authored-by: vfdev <vfdev.5@gmail.com>
parent df4003fd
...@@ -552,24 +552,25 @@ class Tester(TransformsTester): ...@@ -552,24 +552,25 @@ class Tester(TransformsTester):
def _test_affine_all_ops(self, tensor, pil_img, scripted_affine): def _test_affine_all_ops(self, tensor, pil_img, scripted_affine):
# 4) Test rotation + translation + scale + shear # 4) Test rotation + translation + scale + shear
test_configs = [ test_configs = [
(45, [5, 6], 1.0, [0.0, 0.0]), (45.5, [5, 6], 1.0, [0.0, 0.0], None),
(33, (5, -4), 1.0, [0.0, 0.0]), (33, (5, -4), 1.0, [0.0, 0.0], [0, 0, 0]),
(45, [-5, 4], 1.2, [0.0, 0.0]), (45, [-5, 4], 1.2, [0.0, 0.0], (1, 2, 3)),
(33, (-4, -8), 2.0, [0.0, 0.0]), (33, (-4, -8), 2.0, [0.0, 0.0], [255, 255, 255]),
(85, (10, -10), 0.7, [0.0, 0.0]), (85, (10, -10), 0.7, [0.0, 0.0], [1, ]),
(0, [0, 0], 1.0, [35.0, ]), (0, [0, 0], 1.0, [35.0, ], (2.0, )),
(-25, [0, 0], 1.2, [0.0, 15.0]), (-25, [0, 0], 1.2, [0.0, 15.0], None),
(-45, [-10, 0], 0.7, [2.0, 5.0]), (-45, [-10, 0], 0.7, [2.0, 5.0], None),
(-45, [-10, -10], 1.2, [4.0, 5.0]), (-45, [-10, -10], 1.2, [4.0, 5.0], None),
(-90, [0, 0], 1.0, [0.0, 0.0]), (-90, [0, 0], 1.0, [0.0, 0.0], None),
] ]
for r in [NEAREST, ]: for r in [NEAREST, ]:
for a, t, s, sh in test_configs: for a, t, s, sh, f in test_configs:
out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, interpolation=r) f_pil = int(f[0]) if f is not None and len(f) == 1 else f
out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, interpolation=r, fill=f_pil)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.affine, scripted_affine]: for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, interpolation=r).cpu() out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, interpolation=r, fill=f).cpu()
if out_tensor.dtype != torch.uint8: if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8) out_tensor = out_tensor.to(torch.uint8)
...@@ -582,7 +583,7 @@ class Tester(TransformsTester): ...@@ -582,7 +583,7 @@ class Tester(TransformsTester):
ratio_diff_pixels, ratio_diff_pixels,
tol, tol,
msg="{}: {}\n{} vs \n{}".format( msg="{}: {}\n{} vs \n{}".format(
(r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7] (r, a, t, s, sh, f), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
) )
) )
...@@ -643,11 +644,12 @@ class Tester(TransformsTester): ...@@ -643,11 +644,12 @@ class Tester(TransformsTester):
for a in range(-180, 180, 17): for a in range(-180, 180, 17):
for e in [True, False]: for e in [True, False]:
for c in centers: for c in centers:
for f in [None, [0, 0, 0], (1, 2, 3), [255, 255, 255], [1, ], (2.0, )]:
out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c) f_pil = int(f[0]) if f is not None and len(f) == 1 else f
out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c, fill=f_pil)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.rotate, scripted_rotate]: for fn in [F.rotate, scripted_rotate]:
out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c).cpu() out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c, fill=f).cpu()
if out_tensor.dtype != torch.uint8: if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8) out_tensor = out_tensor.to(torch.uint8)
...@@ -657,8 +659,8 @@ class Tester(TransformsTester): ...@@ -657,8 +659,8 @@ class Tester(TransformsTester):
out_pil_tensor.shape, out_pil_tensor.shape,
msg="{}: {} vs {}".format( msg="{}: {} vs {}".format(
(img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape (img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
) ))
)
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0 num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 3% of different pixels # Tolerance : less than 3% of different pixels
...@@ -666,7 +668,7 @@ class Tester(TransformsTester): ...@@ -666,7 +668,7 @@ class Tester(TransformsTester):
ratio_diff_pixels, ratio_diff_pixels,
0.03, 0.03,
msg="{}: {}\n{} vs \n{}".format( msg="{}: {}\n{} vs \n{}".format(
(img_size, r, dt, a, e, c), (img_size, r, dt, a, e, c, f),
ratio_diff_pixels, ratio_diff_pixels,
out_tensor[0, :7, :7], out_tensor[0, :7, :7],
out_pil_tensor[0, :7, :7] out_pil_tensor[0, :7, :7]
...@@ -721,13 +723,16 @@ class Tester(TransformsTester): ...@@ -721,13 +723,16 @@ class Tester(TransformsTester):
def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs): def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs):
dt = tensor.dtype dt = tensor.dtype
for f in [None, [0, 0, 0], [1, 2, 3], [255, 255, 255], [1, ], (2.0, )]:
for r in [NEAREST, ]: for r in [NEAREST, ]:
for spoints, epoints in test_configs: for spoints, epoints in test_configs:
out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r) f_pil = int(f[0]) if f is not None and len(f) == 1 else f
out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r,
fill=f_pil)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.perspective, scripted_transform]: for fn in [F.perspective, scripted_transform]:
out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r).cpu() out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r, fill=f).cpu()
if out_tensor.dtype != torch.uint8: if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8) out_tensor = out_tensor.to(torch.uint8)
...@@ -739,7 +744,7 @@ class Tester(TransformsTester): ...@@ -739,7 +744,7 @@ class Tester(TransformsTester):
ratio_diff_pixels, ratio_diff_pixels,
0.05, 0.05,
msg="{}: {}\n{} vs \n{}".format( msg="{}: {}\n{} vs \n{}".format(
(r, dt, spoints, epoints), (f, r, dt, spoints, epoints),
ratio_diff_pixels, ratio_diff_pixels,
out_tensor[0, :7, :7], out_tensor[0, :7, :7],
out_pil_tensor[0, :7, :7] out_pil_tensor[0, :7, :7]
......
...@@ -349,9 +349,10 @@ class Tester(TransformsTester): ...@@ -349,9 +349,10 @@ class Tester(TransformsTester):
for translate in [(0.1, 0.2), [0.2, 0.1]]: for translate in [(0.1, 0.2), [0.2, 0.1]]:
for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]: for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
for interpolation in [NEAREST, BILINEAR]: for interpolation in [NEAREST, BILINEAR]:
for fill in [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
transform = T.RandomAffine( transform = T.RandomAffine(
degrees=degrees, translate=translate, degrees=degrees, translate=translate,
scale=scale, shear=shear, interpolation=interpolation scale=scale, shear=shear, interpolation=interpolation, fill=fill
) )
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
...@@ -369,8 +370,9 @@ class Tester(TransformsTester): ...@@ -369,8 +370,9 @@ class Tester(TransformsTester):
for expand in [True, False]: for expand in [True, False]:
for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]: for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
for interpolation in [NEAREST, BILINEAR]: for interpolation in [NEAREST, BILINEAR]:
for fill in [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
transform = T.RandomRotation( transform = T.RandomRotation(
degrees=degrees, interpolation=interpolation, expand=expand, center=center degrees=degrees, interpolation=interpolation, expand=expand, center=center, fill=fill
) )
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
...@@ -386,9 +388,11 @@ class Tester(TransformsTester): ...@@ -386,9 +388,11 @@ class Tester(TransformsTester):
for distortion_scale in np.linspace(0.1, 1.0, num=20): for distortion_scale in np.linspace(0.1, 1.0, num=20):
for interpolation in [NEAREST, BILINEAR]: for interpolation in [NEAREST, BILINEAR]:
for fill in [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
transform = T.RandomPerspective( transform = T.RandomPerspective(
distortion_scale=distortion_scale, distortion_scale=distortion_scale,
interpolation=interpolation interpolation=interpolation,
fill=fill
) )
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
......
...@@ -557,7 +557,7 @@ def perspective( ...@@ -557,7 +557,7 @@ def perspective(
startpoints: List[List[int]], startpoints: List[List[int]],
endpoints: List[List[int]], endpoints: List[List[int]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[int] = None fill: Optional[List[float]] = None
) -> Tensor: ) -> Tensor:
"""Perform perspective transform of the given image. """Perform perspective transform of the given image.
The image can be a PIL Image or a Tensor, in which case it is expected The image can be a PIL Image or a Tensor, in which case it is expected
...@@ -573,10 +573,12 @@ def perspective( ...@@ -573,10 +573,12 @@ def perspective(
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
fill (n-tuple or int or float): Pixel fill value for area outside the rotated fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed
image. If int or float, the value is used for all bands respectively. image. If int or float, the value is used for all bands respectively.
This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor This option is supported for PIL image and Tensor inputs.
input. Fill value for the area outside the transform in the output image is always 0. In torchscript mode single int/float value is not supported, please use a tuple
or list of length 1: ``[value, ]``.
If input is PIL Image, the option is only available for ``Pillow>=5.0.0``.
Returns: Returns:
PIL Image or Tensor: transformed Image. PIL Image or Tensor: transformed Image.
...@@ -871,7 +873,7 @@ def _get_inverse_affine_matrix( ...@@ -871,7 +873,7 @@ def _get_inverse_affine_matrix(
def rotate( def rotate(
img: Tensor, angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, img: Tensor, angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST,
expand: bool = False, center: Optional[List[int]] = None, expand: bool = False, center: Optional[List[int]] = None,
fill: Optional[int] = None, resample: Optional[int] = None fill: Optional[List[float]] = None, resample: Optional[int] = None
) -> Tensor: ) -> Tensor:
"""Rotate the image by angle. """Rotate the image by angle.
The image can be a PIL Image or a Tensor, in which case it is expected The image can be a PIL Image or a Tensor, in which case it is expected
...@@ -890,13 +892,12 @@ def rotate( ...@@ -890,13 +892,12 @@ def rotate(
Note that the expand flag assumes rotation around the center and no translation. Note that the expand flag assumes rotation around the center and no translation.
center (list or tuple, optional): Optional center of rotation. Origin is the upper left corner. center (list or tuple, optional): Optional center of rotation. Origin is the upper left corner.
Default is the center of the image. Default is the center of the image.
fill (n-tuple or int or float): Pixel fill value for area outside the rotated fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed
image. If int or float, the value is used for all bands respectively. image. If int or float, the value is used for all bands respectively.
Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``. This option is supported for PIL image and Tensor inputs.
This option is not supported for Tensor input. Fill value for the area outside the transform in the output In torchscript mode single int/float value is not supported, please use a tuple
image is always 0. or list of length 1: ``[value, ]``.
resample (int, optional): deprecated argument and will be removed since v0.10.0. If input is PIL Image, the option is only available for ``Pillow>=5.2.0``.
Please use `arg`:interpolation: instead.
Returns: Returns:
PIL Image or Tensor: Rotated image. PIL Image or Tensor: Rotated image.
...@@ -945,8 +946,8 @@ def rotate( ...@@ -945,8 +946,8 @@ def rotate(
def affine( def affine(
img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float], img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[int] = None, interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None,
resample: Optional[int] = None, fillcolor: Optional[int] = None resample: Optional[int] = None, fillcolor: Optional[List[float]] = None
) -> Tensor: ) -> Tensor:
"""Apply affine transformation on the image keeping image center invariant. """Apply affine transformation on the image keeping image center invariant.
The image can be a PIL Image or a Tensor, in which case it is expected The image can be a PIL Image or a Tensor, in which case it is expected
...@@ -964,10 +965,13 @@ def affine( ...@@ -964,10 +965,13 @@ def affine(
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
fill (int): Optional fill color for the area outside the transform in the output image (Pillow>=5.0.0). fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed
This option is not supported for Tensor input. Fill value for the area outside the transform in the output image. If int or float, the value is used for all bands respectively.
image is always 0. This option is supported for PIL image and Tensor inputs.
fillcolor (tuple or int, optional): deprecated argument and will be removed since v0.10.0. In torchscript mode single int/float value is not supported, please use a tuple
or list of length 1: ``[value, ]``.
If input is PIL Image, the option is only available for ``Pillow>=5.0.0``.
fillcolor (sequence, int, float): deprecated argument and will be removed since v0.10.0.
Please use `arg`:fill: instead. Please use `arg`:fill: instead.
resample (int, optional): deprecated argument and will be removed since v0.10.0. resample (int, optional): deprecated argument and will be removed since v0.10.0.
Please use `arg`:interpolation: instead. Please use `arg`:interpolation: instead.
......
...@@ -465,11 +465,14 @@ def _parse_fill(fill, img, min_pil_version, name="fillcolor"): ...@@ -465,11 +465,14 @@ def _parse_fill(fill, img, min_pil_version, name="fillcolor"):
fill = 0 fill = 0
if isinstance(fill, (int, float)) and num_bands > 1: if isinstance(fill, (int, float)) and num_bands > 1:
fill = tuple([fill] * num_bands) fill = tuple([fill] * num_bands)
if not isinstance(fill, (int, float)) and len(fill) != num_bands: if isinstance(fill, (list, tuple)):
if len(fill) != num_bands:
msg = ("The number of elements in 'fill' does not match the number of " msg = ("The number of elements in 'fill' does not match the number of "
"bands of the image ({} != {})") "bands of the image ({} != {})")
raise ValueError(msg.format(len(fill), num_bands)) raise ValueError(msg.format(len(fill), num_bands))
fill = tuple(fill)
return {name: fill} return {name: fill}
......
...@@ -835,7 +835,7 @@ def _assert_grid_transform_inputs( ...@@ -835,7 +835,7 @@ def _assert_grid_transform_inputs(
img: Tensor, img: Tensor,
matrix: Optional[List[float]], matrix: Optional[List[float]],
interpolation: str, interpolation: str,
fill: Optional[int], fill: Optional[List[float]],
supported_interpolation_modes: List[str], supported_interpolation_modes: List[str],
coeffs: Optional[List[float]] = None, coeffs: Optional[List[float]] = None,
): ):
...@@ -851,8 +851,15 @@ def _assert_grid_transform_inputs( ...@@ -851,8 +851,15 @@ def _assert_grid_transform_inputs(
if coeffs is not None and len(coeffs) != 8: if coeffs is not None and len(coeffs) != 8:
raise ValueError("Argument coeffs should have 8 float values") raise ValueError("Argument coeffs should have 8 float values")
if fill is not None and not (isinstance(fill, (int, float)) and fill == 0): if fill is not None and not isinstance(fill, (int, float, tuple, list)):
warnings.warn("Argument fill is not supported for Tensor input. Fill value is zero") warnings.warn("Argument fill should be either int, float, tuple or list")
# Check fill
num_channels = _get_image_num_channels(img)
if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
msg = ("The number of elements in 'fill' cannot broadcast to match the number of "
"channels of the image ({} != {})")
raise ValueError(msg.format(len(fill), num_channels))
if interpolation not in supported_interpolation_modes: if interpolation not in supported_interpolation_modes:
raise ValueError("Interpolation mode '{}' is unsupported with Tensor input".format(interpolation)) raise ValueError("Interpolation mode '{}' is unsupported with Tensor input".format(interpolation))
...@@ -887,15 +894,34 @@ def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtyp ...@@ -887,15 +894,34 @@ def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtyp
return img return img
def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor: def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[List[float]]) -> Tensor:
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype, ]) img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype, ])
if img.shape[0] > 1: if img.shape[0] > 1:
# Apply same grid to a batch of images # Apply same grid to a batch of images
grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3]) grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])
# Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
if fill is not None:
dummy = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
img = torch.cat((img, dummy), dim=1)
img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False) img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
# Fill with required color
if fill is not None:
mask = img[:, -1:, :, :] # N * 1 * H * W
img = img[:, :-1, :, :] # N * C * H * W
mask = mask.expand_as(img)
len_fill = len(fill) if isinstance(fill, (tuple, list)) else 1
fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
if mode == 'nearest':
mask = mask < 0.5
img[mask] = fill_img[mask]
else: # 'bilinear'
img = img * mask + (1.0 - mask) * fill_img
img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype) img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
return img return img
...@@ -923,7 +949,7 @@ def _gen_affine_grid( ...@@ -923,7 +949,7 @@ def _gen_affine_grid(
def affine( def affine(
img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[int] = None img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[List[float]] = None
) -> Tensor: ) -> Tensor:
"""PRIVATE METHOD. Apply affine transformation on the Tensor image keeping image center invariant. """PRIVATE METHOD. Apply affine transformation on the Tensor image keeping image center invariant.
...@@ -936,8 +962,8 @@ def affine( ...@@ -936,8 +962,8 @@ def affine(
img (Tensor): image to be rotated. img (Tensor): image to be rotated.
matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation. matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation.
interpolation (str): An optional resampling filter. Default is "nearest". Other supported values: "bilinear". interpolation (str): An optional resampling filter. Default is "nearest". Other supported values: "bilinear".
fill (int, optional): this option is not supported for Tensor input. Fill value for the area outside the fill (sequence or int or float, optional): Optional fill value, default None.
transform in the output image is always 0. If None, fill with 0.
Returns: Returns:
Tensor: Transformed image. Tensor: Transformed image.
...@@ -949,7 +975,7 @@ def affine( ...@@ -949,7 +975,7 @@ def affine(
shape = img.shape shape = img.shape
# grid will be generated on the same device as theta and img # grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2]) grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
return _apply_grid_transform(img, grid, interpolation) return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]: def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
...@@ -979,7 +1005,7 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int] ...@@ -979,7 +1005,7 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
def rotate( def rotate(
img: Tensor, matrix: List[float], interpolation: str = "nearest", img: Tensor, matrix: List[float], interpolation: str = "nearest",
expand: bool = False, fill: Optional[int] = None expand: bool = False, fill: Optional[List[float]] = None
) -> Tensor: ) -> Tensor:
"""PRIVATE METHOD. Rotate the Tensor image by angle. """PRIVATE METHOD. Rotate the Tensor image by angle.
...@@ -997,8 +1023,8 @@ def rotate( ...@@ -997,8 +1023,8 @@ def rotate(
If true, expands the output image to make it large enough to hold the entire rotated image. If true, expands the output image to make it large enough to hold the entire rotated image.
If false or omitted, make the output image the same size as the input image. If false or omitted, make the output image the same size as the input image.
Note that the expand flag assumes rotation around the center and no translation. Note that the expand flag assumes rotation around the center and no translation.
fill (n-tuple or int or float): this option is not supported for Tensor input. fill (sequence or int or float, optional): Optional fill value, default None.
Fill value for the area outside the transform in the output image is always 0. If None, fill with 0.
Returns: Returns:
Tensor: Rotated image. Tensor: Rotated image.
...@@ -1013,7 +1039,8 @@ def rotate( ...@@ -1013,7 +1039,8 @@ def rotate(
theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
# grid will be generated on the same device as theta and img # grid will be generated on the same device as theta and img
grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh) grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
return _apply_grid_transform(img, grid, interpolation)
return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device): def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device):
...@@ -1050,7 +1077,7 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, ...@@ -1050,7 +1077,7 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
def perspective( def perspective(
img: Tensor, perspective_coeffs: List[float], interpolation: str = "bilinear", fill: Optional[int] = None img: Tensor, perspective_coeffs: List[float], interpolation: str = "bilinear", fill: Optional[List[float]] = None
) -> Tensor: ) -> Tensor:
"""PRIVATE METHOD. Perform perspective transform of the given Tensor image. """PRIVATE METHOD. Perform perspective transform of the given Tensor image.
...@@ -1063,8 +1090,8 @@ def perspective( ...@@ -1063,8 +1090,8 @@ def perspective(
img (Tensor): Image to be transformed. img (Tensor): Image to be transformed.
perspective_coeffs (list of float): perspective transformation coefficients. perspective_coeffs (list of float): perspective transformation coefficients.
interpolation (str): Interpolation type. Default, "bilinear". interpolation (str): Interpolation type. Default, "bilinear".
fill (n-tuple or int or float): this option is not supported for Tensor input. Fill value for the area fill (sequence or int or float, optional): Optional fill value, default None.
outside the transform in the output image is always 0. If None, fill with 0.
Returns: Returns:
Tensor: transformed image. Tensor: transformed image.
...@@ -1084,7 +1111,7 @@ def perspective( ...@@ -1084,7 +1111,7 @@ def perspective(
ow, oh = img.shape[-1], img.shape[-2] ow, oh = img.shape[-1], img.shape[-2]
dtype = img.dtype if torch.is_floating_point(img) else torch.float32 dtype = img.dtype if torch.is_floating_point(img) else torch.float32
grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device) grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device)
return _apply_grid_transform(img, grid, interpolation) return _apply_grid_transform(img, grid, interpolation, fill=fill)
def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor: def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
......
...@@ -667,10 +667,10 @@ class RandomPerspective(torch.nn.Module): ...@@ -667,10 +667,10 @@ class RandomPerspective(torch.nn.Module):
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
fill (n-tuple or int or float): Pixel fill value for area outside the rotated fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed
image. If int or float, the value is used for all bands respectively. Default is 0. image. If int or float, the value is used for all bands respectively.
This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor This option is supported for PIL image and Tensor inputs.
input. Fill value for the area outside the transform in the output image is always 0. If input is PIL Image, the option is only available for ``Pillow>=5.0.0``.
""" """
def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0): def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0):
...@@ -697,10 +697,18 @@ class RandomPerspective(torch.nn.Module): ...@@ -697,10 +697,18 @@ class RandomPerspective(torch.nn.Module):
Returns: Returns:
PIL Image or Tensor: Randomly transformed image. PIL Image or Tensor: Randomly transformed image.
""" """
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * F._get_image_num_channels(img)
else:
fill = [float(f) for f in fill]
if torch.rand(1) < self.p: if torch.rand(1) < self.p:
width, height = F._get_image_size(img) width, height = F._get_image_size(img)
startpoints, endpoints = self.get_params(width, height, self.distortion_scale) startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
return F.perspective(img, startpoints, endpoints, self.interpolation, self.fill) return F.perspective(img, startpoints, endpoints, self.interpolation, fill)
return img return img
@staticmethod @staticmethod
...@@ -1157,11 +1165,10 @@ class RandomRotation(torch.nn.Module): ...@@ -1157,11 +1165,10 @@ class RandomRotation(torch.nn.Module):
Note that the expand flag assumes rotation around the center and no translation. Note that the expand flag assumes rotation around the center and no translation.
center (list or tuple, optional): Optional center of rotation, (x, y). Origin is the upper left corner. center (list or tuple, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
Default is the center of the image. Default is the center of the image.
fill (n-tuple or int or float): Pixel fill value for area outside the rotated fill (sequence or int or float, optional): Pixel fill value for the area outside the rotated
image. If int or float, the value is used for all bands respectively. image. If int or float, the value is used for all bands respectively.
Defaults to 0 for all bands. This option is only available for Pillow>=5.2.0. This option is supported for PIL image and Tensor inputs.
This option is not supported for Tensor input. Fill value for the area outside the transform in the output If input is PIL Image, the option is only available for ``Pillow>=5.2.0``.
image is always 0.
resample (int, optional): deprecated argument and will be removed since v0.10.0. resample (int, optional): deprecated argument and will be removed since v0.10.0.
Please use `arg`:interpolation: instead. Please use `arg`:interpolation: instead.
...@@ -1216,8 +1223,15 @@ class RandomRotation(torch.nn.Module): ...@@ -1216,8 +1223,15 @@ class RandomRotation(torch.nn.Module):
Returns: Returns:
PIL Image or Tensor: Rotated image. PIL Image or Tensor: Rotated image.
""" """
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * F._get_image_num_channels(img)
else:
fill = [float(f) for f in fill]
angle = self.get_params(self.degrees) angle = self.get_params(self.degrees)
return F.rotate(img, angle, self.interpolation, self.expand, self.center, self.fill)
return F.rotate(img, angle, self.resample, self.expand, self.center, fill)
def __repr__(self): def __repr__(self):
interpolate_str = self.interpolation.value interpolate_str = self.interpolation.value
...@@ -1257,10 +1271,11 @@ class RandomAffine(torch.nn.Module): ...@@ -1257,10 +1271,11 @@ class RandomAffine(torch.nn.Module):
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
fill (tuple or int): Optional fill color (Tuple for RGB Image and int for grayscale) for the area fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed
outside the transform in the output image (Pillow>=5.0.0). This option is not supported for Tensor image. If int or float, the value is used for all bands respectively.
input. Fill value for the area outside the transform in the output image is always 0. This option is supported for PIL image and Tensor inputs.
fillcolor (tuple or int, optional): deprecated argument and will be removed since v0.10.0. If input is PIL Image, the option is only available for ``Pillow>=5.0.0``.
fillcolor (sequence or int or float, optional): deprecated argument and will be removed since v0.10.0.
Please use `arg`:fill: instead. Please use `arg`:fill: instead.
resample (int, optional): deprecated argument and will be removed since v0.10.0. resample (int, optional): deprecated argument and will be removed since v0.10.0.
Please use `arg`:interpolation: instead. Please use `arg`:interpolation: instead.
...@@ -1363,11 +1378,18 @@ class RandomAffine(torch.nn.Module): ...@@ -1363,11 +1378,18 @@ class RandomAffine(torch.nn.Module):
Returns: Returns:
PIL Image or Tensor: Affine transformed image. PIL Image or Tensor: Affine transformed image.
""" """
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * F._get_image_num_channels(img)
else:
fill = [float(f) for f in fill]
img_size = F._get_image_size(img) img_size = F._get_image_size(img)
ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size) ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
return F.affine(img, *ret, interpolation=self.interpolation, fill=self.fill)
return F.affine(img, *ret, interpolation=self.interpolation, fill=fill)
def __repr__(self): def __repr__(self):
s = '{name}(degrees={degrees}' s = '{name}(degrees={degrees}'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment