Unverified commit 21deb4d0, authored by Zhengyang Feng, committed by GitHub

Fill color support for tensor affine transforms (#2904)



* Fill color support for tensor affine transforms

* PEP fix

* Docstring changes and float support

* Docstring update for transforms and float type cast

* Cast only for Tensor

* Temporary patch for lack of Union type support, plus an extra unit test

* More plausible bilinear filling for tensors

* Keep things simple & New docstrings

* Fix lint and other issues after merge

* make it in one line

* Docstring and some code modifications

* More tests and corresponding changes for transforms and docstring changes

* Simplify test configs

* Update test_functional_tensor.py

* Update test_functional_tensor.py

* Move assertions
Co-authored-by: vfdev <vfdev.5@gmail.com>
parent df4003fd
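
The diff below adds `fill` support for tensor inputs across `affine`, `rotate`, and `perspective`. As orientation, here is a minimal usage sketch of what the patch enables (values are illustrative; it assumes a torchvision build containing this commit):

```python
# Minimal sketch of the API this commit enables: `fill` now also works when
# the input is a Tensor, not only a PIL Image. Values here are illustrative.
import torch
import torchvision.transforms.functional as F

img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)  # C x H x W image

# One fill value per channel: paint the exposed corners red.
out = F.rotate(img, angle=30.0, fill=[255.0, 0.0, 0.0])

# A length-1 sequence broadcasts the same value to every channel.
out = F.affine(img, angle=15.0, translate=[5, 5], scale=1.0, shear=[0.0, 0.0],
               fill=[127.0])
```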
@@ -552,24 +552,25 @@ class Tester(TransformsTester):
     def _test_affine_all_ops(self, tensor, pil_img, scripted_affine):
         # 4) Test rotation + translation + scale + shear
         test_configs = [
-            (45, [5, 6], 1.0, [0.0, 0.0]),
-            (33, (5, -4), 1.0, [0.0, 0.0]),
-            (45, [-5, 4], 1.2, [0.0, 0.0]),
-            (33, (-4, -8), 2.0, [0.0, 0.0]),
-            (85, (10, -10), 0.7, [0.0, 0.0]),
-            (0, [0, 0], 1.0, [35.0, ]),
-            (-25, [0, 0], 1.2, [0.0, 15.0]),
-            (-45, [-10, 0], 0.7, [2.0, 5.0]),
-            (-45, [-10, -10], 1.2, [4.0, 5.0]),
-            (-90, [0, 0], 1.0, [0.0, 0.0]),
+            (45.5, [5, 6], 1.0, [0.0, 0.0], None),
+            (33, (5, -4), 1.0, [0.0, 0.0], [0, 0, 0]),
+            (45, [-5, 4], 1.2, [0.0, 0.0], (1, 2, 3)),
+            (33, (-4, -8), 2.0, [0.0, 0.0], [255, 255, 255]),
+            (85, (10, -10), 0.7, [0.0, 0.0], [1, ]),
+            (0, [0, 0], 1.0, [35.0, ], (2.0, )),
+            (-25, [0, 0], 1.2, [0.0, 15.0], None),
+            (-45, [-10, 0], 0.7, [2.0, 5.0], None),
+            (-45, [-10, -10], 1.2, [4.0, 5.0], None),
+            (-90, [0, 0], 1.0, [0.0, 0.0], None),
         ]
         for r in [NEAREST, ]:
-            for a, t, s, sh in test_configs:
-                out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, interpolation=r)
+            for a, t, s, sh, f in test_configs:
+                f_pil = int(f[0]) if f is not None and len(f) == 1 else f
+                out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, interpolation=r, fill=f_pil)
                 out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
                 for fn in [F.affine, scripted_affine]:
-                    out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, interpolation=r).cpu()
+                    out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, interpolation=r, fill=f).cpu()
                     if out_tensor.dtype != torch.uint8:
                         out_tensor = out_tensor.to(torch.uint8)
@@ -582,7 +583,7 @@ class Tester(TransformsTester):
                         ratio_diff_pixels,
                         tol,
                         msg="{}: {}\n{} vs \n{}".format(
-                            (r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
+                            (r, a, t, s, sh, f), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
                         )
                     )
@@ -643,35 +644,36 @@ class Tester(TransformsTester):
             for a in range(-180, 180, 17):
                 for e in [True, False]:
                     for c in centers:
-                        out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c)
-                        out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
-                        for fn in [F.rotate, scripted_rotate]:
-                            out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c).cpu()
-                            if out_tensor.dtype != torch.uint8:
-                                out_tensor = out_tensor.to(torch.uint8)
-                            self.assertEqual(
-                                out_tensor.shape,
-                                out_pil_tensor.shape,
-                                msg="{}: {} vs {}".format(
-                                    (img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
-                                )
-                            )
-                            num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
-                            ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
-                            # Tolerance : less than 3% of different pixels
-                            self.assertLess(
-                                ratio_diff_pixels,
-                                0.03,
-                                msg="{}: {}\n{} vs \n{}".format(
-                                    (img_size, r, dt, a, e, c),
-                                    ratio_diff_pixels,
-                                    out_tensor[0, :7, :7],
-                                    out_pil_tensor[0, :7, :7]
-                                )
-                            )
+                        for f in [None, [0, 0, 0], (1, 2, 3), [255, 255, 255], [1, ], (2.0, )]:
+                            f_pil = int(f[0]) if f is not None and len(f) == 1 else f
+                            out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c, fill=f_pil)
+                            out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+                            for fn in [F.rotate, scripted_rotate]:
+                                out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c, fill=f).cpu()
+                                if out_tensor.dtype != torch.uint8:
+                                    out_tensor = out_tensor.to(torch.uint8)
+                                self.assertEqual(
+                                    out_tensor.shape,
+                                    out_pil_tensor.shape,
+                                    msg="{}: {} vs {}".format(
+                                        (img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape
+                                    ))
+                                num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+                                ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
+                                # Tolerance : less than 3% of different pixels
+                                self.assertLess(
+                                    ratio_diff_pixels,
+                                    0.03,
+                                    msg="{}: {}\n{} vs \n{}".format(
+                                        (img_size, r, dt, a, e, c, f),
+                                        ratio_diff_pixels,
+                                        out_tensor[0, :7, :7],
+                                        out_pil_tensor[0, :7, :7]
+                                    )
+                                )

     def test_rotate(self):
         # Tests on square image
@@ -721,30 +723,33 @@ class Tester(TransformsTester):
     def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs):
         dt = tensor.dtype
-        for r in [NEAREST, ]:
-            for spoints, epoints in test_configs:
-                out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r)
-                out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+        for f in [None, [0, 0, 0], [1, 2, 3], [255, 255, 255], [1, ], (2.0, )]:
+            for r in [NEAREST, ]:
+                for spoints, epoints in test_configs:
+                    f_pil = int(f[0]) if f is not None and len(f) == 1 else f
+                    out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r,
+                                                fill=f_pil)
+                    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))

-                for fn in [F.perspective, scripted_transform]:
-                    out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r).cpu()
+                    for fn in [F.perspective, scripted_transform]:
+                        out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=r, fill=f).cpu()

-                    if out_tensor.dtype != torch.uint8:
-                        out_tensor = out_tensor.to(torch.uint8)
+                        if out_tensor.dtype != torch.uint8:
+                            out_tensor = out_tensor.to(torch.uint8)

-                    num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
-                    ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
-                    # Tolerance : less than 5% of different pixels
-                    self.assertLess(
-                        ratio_diff_pixels,
-                        0.05,
-                        msg="{}: {}\n{} vs \n{}".format(
-                            (r, dt, spoints, epoints),
-                            ratio_diff_pixels,
-                            out_tensor[0, :7, :7],
-                            out_pil_tensor[0, :7, :7]
-                        )
-                    )
+                        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+                        ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
+                        # Tolerance : less than 5% of different pixels
+                        self.assertLess(
+                            ratio_diff_pixels,
+                            0.05,
+                            msg="{}: {}\n{} vs \n{}".format(
+                                (f, r, dt, spoints, epoints),
+                                ratio_diff_pixels,
+                                out_tensor[0, :7, :7],
+                                out_pil_tensor[0, :7, :7]
+                            )
+                        )

     def test_perspective(self):
......
@@ -349,14 +349,15 @@ class Tester(TransformsTester):
             for translate in [(0.1, 0.2), [0.2, 0.1]]:
                 for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
                     for interpolation in [NEAREST, BILINEAR]:
-                        transform = T.RandomAffine(
-                            degrees=degrees, translate=translate,
-                            scale=scale, shear=shear, interpolation=interpolation
-                        )
-                        s_transform = torch.jit.script(transform)
+                        for fill in [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
+                            transform = T.RandomAffine(
+                                degrees=degrees, translate=translate,
+                                scale=scale, shear=shear, interpolation=interpolation, fill=fill
+                            )
+                            s_transform = torch.jit.script(transform)

-                        self._test_transform_vs_scripted(transform, s_transform, tensor)
-                        self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
+                            self._test_transform_vs_scripted(transform, s_transform, tensor)
+                            self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

         with get_tmp_dir() as tmp_dir:
             s_transform.save(os.path.join(tmp_dir, "t_random_affine.pt"))
@@ -369,13 +370,14 @@ class Tester(TransformsTester):
         for expand in [True, False]:
             for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]:
                 for interpolation in [NEAREST, BILINEAR]:
-                    transform = T.RandomRotation(
-                        degrees=degrees, interpolation=interpolation, expand=expand, center=center
-                    )
-                    s_transform = torch.jit.script(transform)
+                    for fill in [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
+                        transform = T.RandomRotation(
+                            degrees=degrees, interpolation=interpolation, expand=expand, center=center, fill=fill
+                        )
+                        s_transform = torch.jit.script(transform)

-                    self._test_transform_vs_scripted(transform, s_transform, tensor)
-                    self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
+                        self._test_transform_vs_scripted(transform, s_transform, tensor)
+                        self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

         with get_tmp_dir() as tmp_dir:
             s_transform.save(os.path.join(tmp_dir, "t_random_rotate.pt"))
@@ -386,14 +388,16 @@ class Tester(TransformsTester):
         for distortion_scale in np.linspace(0.1, 1.0, num=20):
             for interpolation in [NEAREST, BILINEAR]:
-                transform = T.RandomPerspective(
-                    distortion_scale=distortion_scale,
-                    interpolation=interpolation
-                )
-                s_transform = torch.jit.script(transform)
+                for fill in [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
+                    transform = T.RandomPerspective(
+                        distortion_scale=distortion_scale,
+                        interpolation=interpolation,
+                        fill=fill
+                    )
+                    s_transform = torch.jit.script(transform)

-                self._test_transform_vs_scripted(transform, s_transform, tensor)
-                self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
+                    self._test_transform_vs_scripted(transform, s_transform, tensor)
+                    self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

         with get_tmp_dir() as tmp_dir:
             s_transform.save(os.path.join(tmp_dir, "t_perspective.pt"))
......
@@ -557,7 +557,7 @@ def perspective(
         startpoints: List[List[int]],
         endpoints: List[List[int]],
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-        fill: Optional[int] = None
+        fill: Optional[List[float]] = None
 ) -> Tensor:
     """Perform perspective transform of the given image.
     The image can be a PIL Image or a Tensor, in which case it is expected
@@ -573,10 +573,12 @@ def perspective(
             :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
             If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
             For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
-        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
+        fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed
             image. If int or float, the value is used for all bands respectively.
-            This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor
-            input. Fill value for the area outside the transform in the output image is always 0.
+            This option is supported for PIL image and Tensor inputs.
+            In torchscript mode single int/float value is not supported, please use a tuple
+            or list of length 1: ``[value, ]``.
+            If input is PIL Image, the option is only available for ``Pillow>=5.0.0``.

     Returns:
         PIL Image or Tensor: transformed Image.
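
The torchscript note in the docstring above is the one behavioral wrinkle: once scripted, `fill` is typed `Optional[List[float]]`, so a bare number no longer passes type checking. A small sketch of the difference (assumes this patch):

```python
# Sketch of the torchscript constraint described in the docstring above:
# eager calls may pass a scalar, scripted calls need a list such as [value].
import torch
import torchvision.transforms.functional as F

img = torch.zeros(3, 32, 32)
scripted_rotate = torch.jit.script(F.rotate)

out = scripted_rotate(img, angle=45.0, fill=[0.5])  # OK: length-1 list broadcasts
# scripted_rotate(img, angle=45.0, fill=0.5)        # rejected by torchscript typing
```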
@@ -871,7 +873,7 @@ def _get_inverse_affine_matrix(
 def rotate(
         img: Tensor, angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST,
         expand: bool = False, center: Optional[List[int]] = None,
-        fill: Optional[int] = None, resample: Optional[int] = None
+        fill: Optional[List[float]] = None, resample: Optional[int] = None
 ) -> Tensor:
     """Rotate the image by angle.
     The image can be a PIL Image or a Tensor, in which case it is expected
@@ -890,13 +892,12 @@ def rotate(
             Note that the expand flag assumes rotation around the center and no translation.
         center (list or tuple, optional): Optional center of rotation. Origin is the upper left corner.
             Default is the center of the image.
-        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
+        fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed
             image. If int or float, the value is used for all bands respectively.
-            Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.
-            This option is not supported for Tensor input. Fill value for the area outside the transform in the output
-            image is always 0.
-        resample (int, optional): deprecated argument and will be removed since v0.10.0.
-            Please use `arg`:interpolation: instead.
+            This option is supported for PIL image and Tensor inputs.
+            In torchscript mode single int/float value is not supported, please use a tuple
+            or list of length 1: ``[value, ]``.
+            If input is PIL Image, the option is only available for ``Pillow>=5.2.0``.

     Returns:
         PIL Image or Tensor: Rotated image.
@@ -945,8 +946,8 @@ def rotate(
 def affine(
         img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float],
-        interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[int] = None,
-        resample: Optional[int] = None, fillcolor: Optional[int] = None
+        interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None,
+        resample: Optional[int] = None, fillcolor: Optional[List[float]] = None
 ) -> Tensor:
     """Apply affine transformation on the image keeping image center invariant.
     The image can be a PIL Image or a Tensor, in which case it is expected
@@ -964,10 +965,13 @@ def affine(
             :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
             If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
             For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
-        fill (int): Optional fill color for the area outside the transform in the output image (Pillow>=5.0.0).
-            This option is not supported for Tensor input. Fill value for the area outside the transform in the output
-            image is always 0.
-        fillcolor (tuple or int, optional): deprecated argument and will be removed since v0.10.0.
+        fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed
+            image. If int or float, the value is used for all bands respectively.
+            This option is supported for PIL image and Tensor inputs.
+            In torchscript mode single int/float value is not supported, please use a tuple
+            or list of length 1: ``[value, ]``.
+            If input is PIL Image, the option is only available for ``Pillow>=5.0.0``.
+        fillcolor (sequence or int or float, optional): deprecated argument and will be removed since v0.10.0.
             Please use `arg`:fill: instead.
         resample (int, optional): deprecated argument and will be removed since v0.10.0.
             Please use `arg`:interpolation: instead.
......
@@ -465,10 +465,13 @@ def _parse_fill(fill, img, min_pil_version, name="fillcolor"):
         fill = 0
     if isinstance(fill, (int, float)) and num_bands > 1:
         fill = tuple([fill] * num_bands)
-    if not isinstance(fill, (int, float)) and len(fill) != num_bands:
-        msg = ("The number of elements in 'fill' does not match the number of "
-               "bands of the image ({} != {})")
-        raise ValueError(msg.format(len(fill), num_bands))
+    if isinstance(fill, (list, tuple)):
+        if len(fill) != num_bands:
+            msg = ("The number of elements in 'fill' does not match the number of "
+                   "bands of the image ({} != {})")
+            raise ValueError(msg.format(len(fill), num_bands))
+        fill = tuple(fill)

     return {name: fill}
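
The net effect on the PIL path: scalar fills are broadcast to one value per band, sequence fills must match the band count exactly, and PIL always receives a tuple. An illustrative re-implementation of that normalization (a sketch, not the library function itself):

```python
# Sketch of the normalization performed by the _parse_fill change above.
def parse_fill_sketch(fill, num_bands, name="fillcolor"):
    if fill is None:
        fill = 0
    if isinstance(fill, (int, float)) and num_bands > 1:
        fill = tuple([fill] * num_bands)  # broadcast a scalar across all bands
    if isinstance(fill, (list, tuple)):
        if len(fill) != num_bands:        # sequences must match the band count
            raise ValueError(
                "The number of elements in 'fill' does not match the number "
                "of bands of the image ({} != {})".format(len(fill), num_bands)
            )
        fill = tuple(fill)                # PIL expects a tuple
    return {name: fill}

assert parse_fill_sketch(128, 3) == {"fillcolor": (128, 128, 128)}
assert parse_fill_sketch([10, 20, 30], 3) == {"fillcolor": (10, 20, 30)}
```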
......
@@ -835,7 +835,7 @@ def _assert_grid_transform_inputs(
     img: Tensor,
     matrix: Optional[List[float]],
     interpolation: str,
-    fill: Optional[int],
+    fill: Optional[List[float]],
     supported_interpolation_modes: List[str],
     coeffs: Optional[List[float]] = None,
 ):
@@ -851,8 +851,15 @@ def _assert_grid_transform_inputs(
     if coeffs is not None and len(coeffs) != 8:
         raise ValueError("Argument coeffs should have 8 float values")

-    if fill is not None and not (isinstance(fill, (int, float)) and fill == 0):
-        warnings.warn("Argument fill is not supported for Tensor input. Fill value is zero")
+    if fill is not None and not isinstance(fill, (int, float, tuple, list)):
+        warnings.warn("Argument fill should be either int, float, tuple or list")
+
+    # Check fill
+    num_channels = _get_image_num_channels(img)
+    if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
+        msg = ("The number of elements in 'fill' cannot broadcast to match the number of "
+               "channels of the image ({} != {})")
+        raise ValueError(msg.format(len(fill), num_channels))

     if interpolation not in supported_interpolation_modes:
         raise ValueError("Interpolation mode '{}' is unsupported with Tensor input".format(interpolation))
@@ -887,15 +894,34 @@ def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtype):
     return img


-def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor:
+def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[List[float]]) -> Tensor:

     img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype, ])

     if img.shape[0] > 1:
         # Apply same grid to a batch of images
         grid = grid.expand(img.shape[0], grid.shape[1], grid.shape[2], grid.shape[3])

+    # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
+    if fill is not None:
+        dummy = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
+        img = torch.cat((img, dummy), dim=1)
+
     img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)

+    # Fill with required color
+    if fill is not None:
+        mask = img[:, -1:, :, :]  # N * 1 * H * W
+        img = img[:, :-1, :, :]  # N * C * H * W
+        mask = mask.expand_as(img)
+        len_fill = len(fill) if isinstance(fill, (tuple, list)) else 1
+        fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
+        if mode == 'nearest':
+            mask = mask < 0.5
+            img[mask] = fill_img[mask]
+        else:  # 'bilinear'
+            img = img * mask + (1.0 - mask) * fill_img
+
     img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
     return img
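
The dummy-mask trick warps the image and an all-ones alpha channel in a single `grid_sample` call: after sampling, the extra channel is 1 where the output came from the source image, 0 in the padded region, and fractional along bilinear edges, so blending toward the fill color antialiases the boundary. A self-contained sketch of the same idea (shapes and the rotation matrix are illustrative, not the library function):

```python
# Standalone sketch of the single-pass mask-based fill used above.
import torch
from torch.nn.functional import affine_grid, grid_sample

img = torch.rand(1, 3, 32, 32)                    # N x C x H x W, float
theta = torch.tensor([[[0.7, -0.7, 0.0],
                       [0.7, 0.7, 0.0]]])         # roughly a 45-degree rotation
grid = affine_grid(theta, size=list(img.shape), align_corners=False)

fill = [1.0, 0.0, 0.0]                            # paint out-of-bounds pixels red
ones = torch.ones(1, 1, 32, 32)                   # dummy alpha channel
out = grid_sample(torch.cat([img, ones], dim=1), grid,
                  mode="bilinear", padding_mode="zeros", align_corners=False)
out, mask = out[:, :-1], out[:, -1:].expand(-1, 3, -1, -1)
fill_img = torch.tensor(fill).view(1, 3, 1, 1).expand_as(out)
out = out * mask + (1.0 - mask) * fill_img        # soft blend at the boundary
```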
@@ -923,7 +949,7 @@ def _gen_affine_grid(
 def affine(
-        img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[int] = None
+        img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[List[float]] = None
 ) -> Tensor:
     """PRIVATE METHOD. Apply affine transformation on the Tensor image keeping image center invariant.
@@ -936,8 +962,8 @@ def affine(
         img (Tensor): image to be rotated.
         matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation.
         interpolation (str): An optional resampling filter. Default is "nearest". Other supported values: "bilinear".
-        fill (int, optional): this option is not supported for Tensor input. Fill value for the area outside the
-            transform in the output image is always 0.
+        fill (sequence or int or float, optional): Optional fill value, default None.
+            If None, fill with 0.

     Returns:
         Tensor: Transformed image.
@@ -949,7 +975,7 @@ def affine(
     shape = img.shape
     # grid will be generated on the same device as theta and img
     grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2])
-    return _apply_grid_transform(img, grid, interpolation)
+    return _apply_grid_transform(img, grid, interpolation, fill=fill)


 def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
@@ -979,7 +1005,7 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
 def rotate(
         img: Tensor, matrix: List[float], interpolation: str = "nearest",
-        expand: bool = False, fill: Optional[int] = None
+        expand: bool = False, fill: Optional[List[float]] = None
 ) -> Tensor:
     """PRIVATE METHOD. Rotate the Tensor image by angle.
@@ -997,8 +1023,8 @@ def rotate(
             If true, expands the output image to make it large enough to hold the entire rotated image.
             If false or omitted, make the output image the same size as the input image.
             Note that the expand flag assumes rotation around the center and no translation.
-        fill (n-tuple or int or float): this option is not supported for Tensor input.
-            Fill value for the area outside the transform in the output image is always 0.
+        fill (sequence or int or float, optional): Optional fill value, default None.
+            If None, fill with 0.

     Returns:
         Tensor: Rotated image.
@@ -1013,7 +1039,8 @@ def rotate(
     theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3)
     # grid will be generated on the same device as theta and img
     grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh)
-    return _apply_grid_transform(img, grid, interpolation)
+    return _apply_grid_transform(img, grid, interpolation, fill=fill)


 def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device):
@@ -1050,7 +1077,7 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device):
 def perspective(
-        img: Tensor, perspective_coeffs: List[float], interpolation: str = "bilinear", fill: Optional[int] = None
+        img: Tensor, perspective_coeffs: List[float], interpolation: str = "bilinear", fill: Optional[List[float]] = None
 ) -> Tensor:
     """PRIVATE METHOD. Perform perspective transform of the given Tensor image.
@@ -1063,8 +1090,8 @@ def perspective(
         img (Tensor): Image to be transformed.
         perspective_coeffs (list of float): perspective transformation coefficients.
         interpolation (str): Interpolation type. Default, "bilinear".
-        fill (n-tuple or int or float): this option is not supported for Tensor input. Fill value for the area
-            outside the transform in the output image is always 0.
+        fill (sequence or int or float, optional): Optional fill value, default None.
+            If None, fill with 0.

     Returns:
         Tensor: transformed image.
@@ -1084,7 +1111,7 @@ def perspective(
     ow, oh = img.shape[-1], img.shape[-2]
     dtype = img.dtype if torch.is_floating_point(img) else torch.float32
     grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device)
-    return _apply_grid_transform(img, grid, interpolation)
+    return _apply_grid_transform(img, grid, interpolation, fill=fill)


 def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
......
@@ -667,10 +667,10 @@ class RandomPerspective(torch.nn.Module):
             :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
             If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
             For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
-        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
-            image. If int or float, the value is used for all bands respectively. Default is 0.
-            This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor
-            input. Fill value for the area outside the transform in the output image is always 0.
+        fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed
+            image. If int or float, the value is used for all bands respectively.
+            This option is supported for PIL image and Tensor inputs.
+            If input is PIL Image, the option is only available for ``Pillow>=5.0.0``.
     """

     def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode.BILINEAR, fill=0):
@@ -697,10 +697,18 @@ class RandomPerspective(torch.nn.Module):
         Returns:
             PIL Image or Tensor: Randomly transformed image.
         """
+        fill = self.fill
+        if isinstance(img, Tensor):
+            if isinstance(fill, (int, float)):
+                fill = [float(fill)] * F._get_image_num_channels(img)
+            else:
+                fill = [float(f) for f in fill]
+
         if torch.rand(1) < self.p:
             width, height = F._get_image_size(img)
             startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
-            return F.perspective(img, startpoints, endpoints, self.interpolation, self.fill)
+            return F.perspective(img, startpoints, endpoints, self.interpolation, fill)
         return img

     @staticmethod
@@ -1157,11 +1165,10 @@ class RandomRotation(torch.nn.Module):
             Note that the expand flag assumes rotation around the center and no translation.
         center (list or tuple, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
             Default is the center of the image.
-        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
+        fill (sequence or int or float, optional): Pixel fill value for the area outside the rotated
             image. If int or float, the value is used for all bands respectively.
-            Defaults to 0 for all bands. This option is only available for Pillow>=5.2.0.
-            This option is not supported for Tensor input. Fill value for the area outside the transform in the output
-            image is always 0.
+            This option is supported for PIL image and Tensor inputs.
+            If input is PIL Image, the option is only available for ``Pillow>=5.2.0``.
         resample (int, optional): deprecated argument and will be removed since v0.10.0.
             Please use `arg`:interpolation: instead.
@@ -1216,8 +1223,15 @@ class RandomRotation(torch.nn.Module):
         Returns:
             PIL Image or Tensor: Rotated image.
         """
+        fill = self.fill
+        if isinstance(img, Tensor):
+            if isinstance(fill, (int, float)):
+                fill = [float(fill)] * F._get_image_num_channels(img)
+            else:
+                fill = [float(f) for f in fill]
         angle = self.get_params(self.degrees)
-        return F.rotate(img, angle, self.interpolation, self.expand, self.center, self.fill)
+        return F.rotate(img, angle, self.interpolation, self.expand, self.center, fill)

     def __repr__(self):
         interpolate_str = self.interpolation.value
@@ -1257,10 +1271,11 @@ class RandomAffine(torch.nn.Module):
             :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
             If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
             For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
-        fill (tuple or int): Optional fill color (Tuple for RGB Image and int for grayscale) for the area
-            outside the transform in the output image (Pillow>=5.0.0). This option is not supported for Tensor
-            input. Fill value for the area outside the transform in the output image is always 0.
-        fillcolor (tuple or int, optional): deprecated argument and will be removed since v0.10.0.
+        fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed
+            image. If int or float, the value is used for all bands respectively.
+            This option is supported for PIL image and Tensor inputs.
+            If input is PIL Image, the option is only available for ``Pillow>=5.0.0``.
+        fillcolor (sequence or int or float, optional): deprecated argument and will be removed since v0.10.0.
             Please use `arg`:fill: instead.
         resample (int, optional): deprecated argument and will be removed since v0.10.0.
             Please use `arg`:interpolation: instead.
@@ -1363,11 +1378,18 @@ class RandomAffine(torch.nn.Module):
         Returns:
             PIL Image or Tensor: Affine transformed image.
         """
+        fill = self.fill
+        if isinstance(img, Tensor):
+            if isinstance(fill, (int, float)):
+                fill = [float(fill)] * F._get_image_num_channels(img)
+            else:
+                fill = [float(f) for f in fill]
+
         img_size = F._get_image_size(img)

         ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
-        return F.affine(img, *ret, interpolation=self.interpolation, fill=self.fill)
+        return F.affine(img, *ret, interpolation=self.interpolation, fill=fill)

     def __repr__(self):
         s = '{name}(degrees={degrees}'
......
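
Taken together, the three transforms normalize a scalar `fill` into a per-channel float list before calling the functional op, which keeps them scriptable. A short usage sketch (illustrative values; assumes this commit):

```python
# Usage sketch for the updated transforms: `fill` now applies to tensor inputs.
import torch
import torchvision.transforms as T

img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)

transform = T.RandomRotation(degrees=30, fill=127)
out = transform(img)  # exposed corners are filled with 127 on every channel

# Scalar fills are expanded in forward(), so scripting still works.
scripted = torch.jit.script(T.RandomAffine(degrees=15, fill=1))
out = scripted(img)
```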