Unverified Commit 2eba1f04 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Unified inputs for grayscale ops and transforms (#2586)

* [WIP] Unify ops Grayscale and RandomGrayscale

* Unified inputs for grayscale op and transforms
- deprecated F.to_grayscale in favor of F.rgb_to_grayscale

* Fixes bug with fp input

* [WIP] Updated code according to review

* Removed unused import
parent 279fca56
......@@ -350,9 +350,12 @@ class TransformsTester(unittest.TestCase):
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
self.assertTrue(tensor.cpu().equal(pil_tensor), msg)
def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, method="mean"):
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))).to(tensor)
err = getattr(torch, method)(tensor - pil_tensor).item()
def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, agg_method="mean"):
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1))).to(tensor)
err = getattr(torch, agg_method)(tensor - pil_tensor).item()
self.assertTrue(
err < tol,
msg="{}: err={}, tol={}: \n{}\nvs\n{}".format(msg, err, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10])
......
......@@ -194,18 +194,29 @@ class Tester(TransformsTester):
def test_adjustments_cuda(self):
self._test_adjustments("cuda")
def _test_rgb_to_grayscale(self, device):
script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)
img_tensor, pil_img = self._create_data(32, 34, device=device)
for num_output_channels in (3, 1):
gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
if num_output_channels == 1:
print(gray_tensor.shape)
self.approxEqualTensorToPIL(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")
s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
self.assertTrue(s_gray_tensor.equal(gray_tensor))
def test_rgb_to_grayscale(self):
script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale)
img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
img_tensor_clone = img_tensor.clone()
grayscale_tensor = F_t.rgb_to_grayscale(img_tensor).to(int)
grayscale_pil_img = torch.tensor(np.array(F.to_grayscale(F.to_pil_image(img_tensor)))).to(int)
max_diff = (grayscale_tensor - grayscale_pil_img).abs().max()
self.assertLess(max_diff, 1.0001)
self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
# scriptable function test
grayscale_script = script_rgb_to_grayscale(img_tensor).to(int)
self.assertTrue(torch.equal(grayscale_script, grayscale_tensor))
self._test_rgb_to_grayscale("cpu")
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
def test_rgb_to_grayscale_cuda(self):
self._test_rgb_to_grayscale("cuda")
def _test_center_crop(self, device):
script_center_crop = torch.jit.script(F.center_crop)
......
......@@ -13,7 +13,7 @@ from common_utils import TransformsTester
class Tester(TransformsTester):
def _test_functional_geom_op(self, func, fn_kwargs):
def _test_functional_op(self, func, fn_kwargs):
if fn_kwargs is None:
fn_kwargs = {}
tensor, pil_img = self._create_data(height=10, width=10)
......@@ -21,7 +21,7 @@ class Tester(TransformsTester):
transformed_pil_img = getattr(F, func)(pil_img, **fn_kwargs)
self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
def _test_class_geom_op(self, method, meth_kwargs=None):
def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs):
if meth_kwargs is None:
meth_kwargs = {}
......@@ -35,21 +35,24 @@ class Tester(TransformsTester):
transformed_tensor = f(tensor)
torch.manual_seed(12)
transformed_pil_img = f(pil_img)
self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
if test_exact_match:
self.compareTensorToPIL(transformed_tensor, transformed_pil_img, **match_kwargs)
else:
self.approxEqualTensorToPIL(transformed_tensor.float(), transformed_pil_img, **match_kwargs)
torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script))
def _test_geom_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
self._test_functional_geom_op(func, fn_kwargs)
self._test_class_geom_op(method, meth_kwargs)
def _test_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
self._test_functional_op(func, fn_kwargs)
self._test_class_op(method, meth_kwargs)
def test_random_horizontal_flip(self):
self._test_geom_op('hflip', 'RandomHorizontalFlip')
self._test_op('hflip', 'RandomHorizontalFlip')
def test_random_vertical_flip(self):
self._test_geom_op('vflip', 'RandomVerticalFlip')
self._test_op('vflip', 'RandomVerticalFlip')
def test_adjustments(self):
fns = ['adjust_brightness', 'adjust_contrast', 'adjust_saturation']
......@@ -80,22 +83,22 @@ class Tester(TransformsTester):
def test_pad(self):
# Test functional.pad (PIL and Tensor) with padding as single int
self._test_functional_geom_op(
self._test_functional_op(
"pad", fn_kwargs={"padding": 2, "fill": 0, "padding_mode": "constant"}
)
# Test functional.pad and transforms.Pad with padding as [int, ]
fn_kwargs = meth_kwargs = {"padding": [2, ], "fill": 0, "padding_mode": "constant"}
self._test_geom_op(
self._test_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
# Test functional.pad and transforms.Pad with padding as list
fn_kwargs = meth_kwargs = {"padding": [4, 4], "fill": 0, "padding_mode": "constant"}
self._test_geom_op(
self._test_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
# Test functional.pad and transforms.Pad with padding as tuple
fn_kwargs = meth_kwargs = {"padding": (2, 2, 2, 2), "fill": 127, "padding_mode": "constant"}
self._test_geom_op(
self._test_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
......@@ -103,7 +106,7 @@ class Tester(TransformsTester):
fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
# Test transforms.RandomCrop with size and padding as tuple
meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, }
self._test_geom_op(
self._test_op(
'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
......@@ -120,17 +123,17 @@ class Tester(TransformsTester):
for padding_config in padding_configs:
config = dict(padding_config)
config["size"] = size
self._test_class_geom_op("RandomCrop", config)
self._test_class_op("RandomCrop", config)
def test_center_crop(self):
fn_kwargs = {"output_size": (4, 5)}
meth_kwargs = {"size": (4, 5), }
self._test_geom_op(
self._test_op(
"center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = {"output_size": (5,)}
meth_kwargs = {"size": (5, )}
self._test_geom_op(
self._test_op(
"center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8)
......@@ -149,7 +152,7 @@ class Tester(TransformsTester):
scripted_fn = torch.jit.script(f)
scripted_fn(tensor)
def _test_geom_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
if fn_kwargs is None:
fn_kwargs = {}
if meth_kwargs is None:
......@@ -178,37 +181,37 @@ class Tester(TransformsTester):
def test_five_crop(self):
fn_kwargs = meth_kwargs = {"size": (5,)}
self._test_geom_op_list_output(
self._test_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [5, ]}
self._test_geom_op_list_output(
self._test_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": (4, 5)}
self._test_geom_op_list_output(
self._test_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [4, 5]}
self._test_geom_op_list_output(
self._test_op_list_output(
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
def test_ten_crop(self):
fn_kwargs = meth_kwargs = {"size": (5,)}
self._test_geom_op_list_output(
self._test_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [5, ]}
self._test_geom_op_list_output(
self._test_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": (4, 5)}
self._test_geom_op_list_output(
self._test_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"size": [4, 5]}
self._test_geom_op_list_output(
self._test_op_list_output(
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
......@@ -312,6 +315,24 @@ class Tester(TransformsTester):
out2 = s_transform(tensor)
self.assertTrue(out1.equal(out2))
def test_to_grayscale(self):
meth_kwargs = {"num_output_channels": 1}
tol = 1.0 + 1e-10
self._test_class_op(
"Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)
meth_kwargs = {"num_output_channels": 3}
self._test_class_op(
"Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)
meth_kwargs = {}
self._test_class_op(
"RandomGrayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)
if __name__ == '__main__':
unittest.main()
......@@ -32,6 +32,13 @@ def _get_image_size(img: Tensor) -> List[int]:
return F_pil._get_image_size(img)
def _get_image_num_channels(img: Tensor) -> int:
if isinstance(img, torch.Tensor):
return F_t._get_image_num_channels(img)
return F_pil._get_image_num_channels(img)
@torch.jit.unused
def _is_numpy(img: Any) -> bool:
return isinstance(img, np.ndarray)
......@@ -951,11 +958,13 @@ def affine(
return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)
@torch.jit.unused
def to_grayscale(img, num_output_channels=1):
"""Convert image to grayscale version of image.
"""Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image.
Args:
img (PIL Image): Image to be converted to grayscale.
img (PIL Image): PIL Image to be converted to grayscale.
num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1.
Returns:
PIL Image: Grayscale version of the image.
......@@ -963,20 +972,35 @@ def to_grayscale(img, num_output_channels=1):
if num_output_channels = 3 : returned image is 3 channel with r = g = b
"""
if not F_pil._is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if num_output_channels == 1:
img = img.convert('L')
elif num_output_channels == 3:
img = img.convert('L')
np_img = np.array(img, dtype=np.uint8)
np_img = np.dstack([np_img, np_img, np_img])
img = Image.fromarray(np_img, 'RGB')
else:
raise ValueError('num_output_channels should be either 1 or 3')
if isinstance(img, Image.Image):
return F_pil.to_grayscale(img, num_output_channels)
return img
raise TypeError("Input should be PIL Image")
def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
"""Convert RGB image to grayscale version of image.
The image can be a PIL Image or a Tensor, in which case it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
Note:
Please, note that this method supports only RGB images as input. For inputs in other color spaces,
please, consider using meth:`~torchvision.transforms.functional.to_grayscale` with PIL Image.
Args:
img (PIL Image or Tensor): RGB Image to be converted to grayscale.
num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1.
Returns:
PIL Image or Tensor: Grayscale version of the image.
if num_output_channels = 1 : returned image is single channel
if num_output_channels = 3 : returned image is 3 channel with r = g = b
"""
if not isinstance(img, torch.Tensor):
return F_pil.to_grayscale(img, num_output_channels)
return F_t.rgb_to_grayscale(img, num_output_channels)
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
......
......@@ -26,6 +26,13 @@ def _get_image_size(img: Any) -> List[int]:
raise TypeError("Unexpected type {}".format(type(img)))
@torch.jit.unused
def _get_image_num_channels(img: Any) -> int:
if _is_pil_image(img):
return 1 if img.mode == 'L' else 3
raise TypeError("Unexpected type {}".format(type(img)))
@torch.jit.unused
def hflip(img):
"""Horizontally flip the given PIL Image.
......@@ -480,3 +487,33 @@ def perspective(img, perspective_coeffs, interpolation=Image.BICUBIC, fill=None)
opts = _parse_fill(fill, img, '5.0.0')
return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs, interpolation, **opts)
@torch.jit.unused
def to_grayscale(img, num_output_channels):
"""Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image.
Args:
img (PIL Image): Image to be converted to grayscale.
num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1.
Returns:
PIL Image: Grayscale version of the image.
if num_output_channels = 1 : returned image is single channel
if num_output_channels = 3 : returned image is 3 channel with r = g = b
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if num_output_channels == 1:
img = img.convert('L')
elif num_output_channels == 3:
img = img.convert('L')
np_img = np.array(img, dtype=np.uint8)
np_img = np.dstack([np_img, np_img, np_img])
img = Image.fromarray(np_img, 'RGB')
else:
raise ValueError('num_output_channels should be either 1 or 3')
return img
......@@ -18,6 +18,15 @@ def _get_image_size(img: Tensor) -> List[int]:
raise TypeError("Unexpected type {}".format(type(img)))
def _get_image_num_channels(img: Tensor) -> int:
if img.ndim == 2:
return 1
elif img.ndim > 2:
return img.shape[-3]
raise TypeError("Unexpected type {}".format(type(img)))
def vflip(img: Tensor) -> Tensor:
"""Vertically flip the given the Image Tensor.
......@@ -67,22 +76,41 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
return img[..., top:top + height, left:left + width]
def rgb_to_grayscale(img: Tensor) -> Tensor:
def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
"""Convert the given RGB Image Tensor to Grayscale.
For RGB to Grayscale conversion, ITU-R 601-2 luma transform is performed which
is L = R * 0.2989 + G * 0.5870 + B * 0.1140
Args:
img (Tensor): Image to be converted to Grayscale in the form [C, H, W].
num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1.
Returns:
Tensor: Grayscale image.
Tensor: Grayscale version of the image.
if num_output_channels = 1 : returned image is single channel
if num_output_channels = 3 : returned image is 3 channel with r = g = b
"""
if img.shape[0] != 3:
raise TypeError('Input Image does not contain 3 Channels')
if img.ndim < 3:
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
c = img.shape[-3]
if c != 3:
raise TypeError("Input image tensor should 3 channels, but found {}".format(c))
if num_output_channels not in (1, 3):
raise ValueError('num_output_channels should be either 1 or 3')
r, g, b = img.unbind(dim=-3)
# This implementation closely follows the TF one:
# https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype)
l_img = l_img.unsqueeze(dim=-3)
if num_output_channels == 3:
return l_img.expand(img.shape)
return (0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2]).to(img.dtype)
return l_img
def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
......@@ -373,8 +401,8 @@ def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = Fa
def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255
return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)
bound = 1.0 if img1.is_floating_point() else 255.0
return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)
def _rgb2hsv(img):
......
......@@ -1354,8 +1354,11 @@ class RandomAffine(torch.nn.Module):
return s.format(name=self.__class__.__name__, **d)
class Grayscale(object):
class Grayscale(torch.nn.Module):
"""Convert image to grayscale.
The image can be a PIL Image or a Tensor, in which case it is expected
to have [..., 3, H, W] shape, where ... means an arbitrary number of leading
dimensions
Args:
num_output_channels (int): (1 or 3) number of channels desired for output image
......@@ -1368,30 +1371,34 @@ class Grayscale(object):
"""
def __init__(self, num_output_channels=1):
super().__init__()
self.num_output_channels = num_output_channels
def __call__(self, img):
def forward(self, img: Tensor) -> Tensor:
"""
Args:
img (PIL Image): Image to be converted to grayscale.
img (PIL Image or Tensor): Image to be converted to grayscale.
Returns:
PIL Image: Randomly grayscaled image.
PIL Image or Tensor: Grayscaled image.
"""
return F.to_grayscale(img, num_output_channels=self.num_output_channels)
return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels)
def __repr__(self):
return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels)
class RandomGrayscale(object):
class RandomGrayscale(torch.nn.Module):
"""Randomly convert image to grayscale with a probability of p (default 0.1).
The image can be a PIL Image or a Tensor, in which case it is expected
to have [..., 3, H, W] shape, where ... means an arbitrary number of leading
dimensions
Args:
p (float): probability that image should be converted to grayscale.
Returns:
PIL Image: Grayscale version of the input image with probability p and unchanged
PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged
with probability (1-p).
- If input image is 1 channel: grayscale version is 1 channel
- If input image is 3 channel: grayscale version is 3 channel with r == g == b
......@@ -1399,19 +1406,20 @@ class RandomGrayscale(object):
"""
def __init__(self, p=0.1):
super().__init__()
self.p = p
def __call__(self, img):
def forward(self, img: Tensor) -> Tensor:
"""
Args:
img (PIL Image): Image to be converted to grayscale.
img (PIL Image or Tensor): Image to be converted to grayscale.
Returns:
PIL Image: Randomly grayscaled image.
PIL Image or Tensor: Randomly grayscaled image.
"""
num_output_channels = 1 if img.mode == 'L' else 3
if random.random() < self.p:
return F.to_grayscale(img, num_output_channels=num_output_channels)
num_output_channels = F._get_image_num_channels(img)
if torch.rand(1) < self.p:
return F.rgb_to_grayscale(img, num_output_channels=num_output_channels)
return img
def __repr__(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment