Unverified commit 2eba1f04, authored by vfdev, committed by GitHub

Unified inputs for grayscale ops and transforms (#2586)

* [WIP] Unify ops Grayscale and RandomGrayscale

* Unified inputs for grayscale op and transforms
- deprecated F.to_grayscale in favor of F.rgb_to_grayscale

* Fixes bug with fp input

* [WIP] Updated code according to review

* Removed unused import
parent 279fca56
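
Note: as a quick illustration of the unified API this commit introduces (a minimal sketch, not part of the commit; names are taken from the diff below), F.rgb_to_grayscale now accepts either a PIL Image or a Tensor and dispatches internally:

    import torch
    from PIL import Image
    import torchvision.transforms.functional as F

    pil_img = Image.new("RGB", (32, 32), color=(120, 60, 200))
    tensor_img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)

    gray_pil = F.rgb_to_grayscale(pil_img, num_output_channels=1)        # PIL Image, mode "L"
    gray_tensor = F.rgb_to_grayscale(tensor_img, num_output_channels=3)  # Tensor, shape (3, 32, 32)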
@@ -350,9 +350,12 @@ class TransformsTester(unittest.TestCase):
         msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
         self.assertTrue(tensor.cpu().equal(pil_tensor), msg)

-    def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, method="mean"):
-        pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))).to(tensor)
-        err = getattr(torch, method)(tensor - pil_tensor).item()
+    def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, agg_method="mean"):
+        np_pil_image = np.array(pil_image)
+        if np_pil_image.ndim == 2:
+            np_pil_image = np_pil_image[:, :, None]
+        pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1))).to(tensor)
+        err = getattr(torch, agg_method)(tensor - pil_tensor).item()
         self.assertTrue(
             err < tol,
             msg="{}: err={}, tol={}: \n{}\nvs\n{}".format(msg, err, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10])
...
@@ -194,18 +194,29 @@ class Tester(TransformsTester):
     def test_adjustments_cuda(self):
         self._test_adjustments("cuda")

+    def _test_rgb_to_grayscale(self, device):
+        script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)
+
+        img_tensor, pil_img = self._create_data(32, 34, device=device)
+
+        for num_output_channels in (3, 1):
+            gray_pil_image = F.rgb_to_grayscale(pil_img, num_output_channels=num_output_channels)
+            gray_tensor = F.rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
+
+            self.approxEqualTensorToPIL(gray_tensor.float(), gray_pil_image, tol=1.0 + 1e-10, agg_method="max")
+
+            s_gray_tensor = script_rgb_to_grayscale(img_tensor, num_output_channels=num_output_channels)
+            self.assertTrue(s_gray_tensor.equal(gray_tensor))
+
     def test_rgb_to_grayscale(self):
-        script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale)
-        img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
-        img_tensor_clone = img_tensor.clone()
-        grayscale_tensor = F_t.rgb_to_grayscale(img_tensor).to(int)
-        grayscale_pil_img = torch.tensor(np.array(F.to_grayscale(F.to_pil_image(img_tensor)))).to(int)
-        max_diff = (grayscale_tensor - grayscale_pil_img).abs().max()
-        self.assertLess(max_diff, 1.0001)
-        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
-        # scriptable function test
-        grayscale_script = script_rgb_to_grayscale(img_tensor).to(int)
-        self.assertTrue(torch.equal(grayscale_script, grayscale_tensor))
+        self._test_rgb_to_grayscale("cpu")
+
+    @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
+    def test_rgb_to_grayscale_cuda(self):
+        self._test_rgb_to_grayscale("cuda")

     def _test_center_crop(self, device):
         script_center_crop = torch.jit.script(F.center_crop)
...
@@ -13,7 +13,7 @@ from common_utils import TransformsTester

 class Tester(TransformsTester):

-    def _test_functional_geom_op(self, func, fn_kwargs):
+    def _test_functional_op(self, func, fn_kwargs):
         if fn_kwargs is None:
             fn_kwargs = {}
         tensor, pil_img = self._create_data(height=10, width=10)
@@ -21,7 +21,7 @@ class Tester(TransformsTester):
         transformed_pil_img = getattr(F, func)(pil_img, **fn_kwargs)
         self.compareTensorToPIL(transformed_tensor, transformed_pil_img)

-    def _test_class_geom_op(self, method, meth_kwargs=None):
+    def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs):
         if meth_kwargs is None:
             meth_kwargs = {}
@@ -35,21 +35,24 @@ class Tester(TransformsTester):
         transformed_tensor = f(tensor)
         torch.manual_seed(12)
         transformed_pil_img = f(pil_img)
-        self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
+        if test_exact_match:
+            self.compareTensorToPIL(transformed_tensor, transformed_pil_img, **match_kwargs)
+        else:
+            self.approxEqualTensorToPIL(transformed_tensor.float(), transformed_pil_img, **match_kwargs)

         torch.manual_seed(12)
         transformed_tensor_script = scripted_fn(tensor)
         self.assertTrue(transformed_tensor.equal(transformed_tensor_script))

-    def _test_geom_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
-        self._test_functional_geom_op(func, fn_kwargs)
-        self._test_class_geom_op(method, meth_kwargs)
+    def _test_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
+        self._test_functional_op(func, fn_kwargs)
+        self._test_class_op(method, meth_kwargs)

     def test_random_horizontal_flip(self):
-        self._test_geom_op('hflip', 'RandomHorizontalFlip')
+        self._test_op('hflip', 'RandomHorizontalFlip')

     def test_random_vertical_flip(self):
-        self._test_geom_op('vflip', 'RandomVerticalFlip')
+        self._test_op('vflip', 'RandomVerticalFlip')

     def test_adjustments(self):
         fns = ['adjust_brightness', 'adjust_contrast', 'adjust_saturation']
@@ -80,22 +83,22 @@ class Tester(TransformsTester):
     def test_pad(self):

         # Test functional.pad (PIL and Tensor) with padding as single int
-        self._test_functional_geom_op(
+        self._test_functional_op(
             "pad", fn_kwargs={"padding": 2, "fill": 0, "padding_mode": "constant"}
         )
         # Test functional.pad and transforms.Pad with padding as [int, ]
         fn_kwargs = meth_kwargs = {"padding": [2, ], "fill": 0, "padding_mode": "constant"}
-        self._test_geom_op(
+        self._test_op(
             "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
         )
         # Test functional.pad and transforms.Pad with padding as list
         fn_kwargs = meth_kwargs = {"padding": [4, 4], "fill": 0, "padding_mode": "constant"}
-        self._test_geom_op(
+        self._test_op(
             "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
         )
         # Test functional.pad and transforms.Pad with padding as tuple
         fn_kwargs = meth_kwargs = {"padding": (2, 2, 2, 2), "fill": 127, "padding_mode": "constant"}
-        self._test_geom_op(
+        self._test_op(
             "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
         )
@@ -103,7 +106,7 @@ class Tester(TransformsTester):
         fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
         # Test transforms.RandomCrop with size and padding as tuple
         meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, }
-        self._test_geom_op(
+        self._test_op(
             'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
         )
@@ -120,17 +123,17 @@ class Tester(TransformsTester):
         for padding_config in padding_configs:
             config = dict(padding_config)
             config["size"] = size
-            self._test_class_geom_op("RandomCrop", config)
+            self._test_class_op("RandomCrop", config)

     def test_center_crop(self):
         fn_kwargs = {"output_size": (4, 5)}
         meth_kwargs = {"size": (4, 5), }
-        self._test_geom_op(
+        self._test_op(
             "center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
         )
         fn_kwargs = {"output_size": (5,)}
         meth_kwargs = {"size": (5, )}
-        self._test_geom_op(
+        self._test_op(
             "center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
         )
         tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8)
@@ -149,7 +152,7 @@ class Tester(TransformsTester):
         scripted_fn = torch.jit.script(f)
         scripted_fn(tensor)

-    def _test_geom_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
+    def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
         if fn_kwargs is None:
             fn_kwargs = {}
         if meth_kwargs is None:
@@ -178,37 +181,37 @@ class Tester(TransformsTester):
     def test_five_crop(self):
         fn_kwargs = meth_kwargs = {"size": (5,)}
-        self._test_geom_op_list_output(
+        self._test_op_list_output(
             "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
         )
         fn_kwargs = meth_kwargs = {"size": [5, ]}
-        self._test_geom_op_list_output(
+        self._test_op_list_output(
             "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
         )
         fn_kwargs = meth_kwargs = {"size": (4, 5)}
-        self._test_geom_op_list_output(
+        self._test_op_list_output(
             "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
         )
         fn_kwargs = meth_kwargs = {"size": [4, 5]}
-        self._test_geom_op_list_output(
+        self._test_op_list_output(
             "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
         )

     def test_ten_crop(self):
         fn_kwargs = meth_kwargs = {"size": (5,)}
-        self._test_geom_op_list_output(
+        self._test_op_list_output(
             "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
         )
         fn_kwargs = meth_kwargs = {"size": [5, ]}
-        self._test_geom_op_list_output(
+        self._test_op_list_output(
             "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
         )
         fn_kwargs = meth_kwargs = {"size": (4, 5)}
-        self._test_geom_op_list_output(
+        self._test_op_list_output(
             "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
         )
         fn_kwargs = meth_kwargs = {"size": [4, 5]}
-        self._test_geom_op_list_output(
+        self._test_op_list_output(
             "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
         )
@@ -312,6 +315,24 @@ class Tester(TransformsTester):
         out2 = s_transform(tensor)
         self.assertTrue(out1.equal(out2))

+    def test_to_grayscale(self):
+        meth_kwargs = {"num_output_channels": 1}
+        tol = 1.0 + 1e-10
+        self._test_class_op(
+            "Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
+        )
+
+        meth_kwargs = {"num_output_channels": 3}
+        self._test_class_op(
+            "Grayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
+        )
+
+        meth_kwargs = {}
+        self._test_class_op(
+            "RandomGrayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
+        )
+

 if __name__ == '__main__':
     unittest.main()
@@ -32,6 +32,13 @@ def _get_image_size(img: Tensor) -> List[int]:
     return F_pil._get_image_size(img)

+def _get_image_num_channels(img: Tensor) -> int:
+    if isinstance(img, torch.Tensor):
+        return F_t._get_image_num_channels(img)
+
+    return F_pil._get_image_num_channels(img)
+

 @torch.jit.unused
 def _is_numpy(img: Any) -> bool:
     return isinstance(img, np.ndarray)
@@ -951,11 +958,13 @@ def affine(
     return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)

+@torch.jit.unused
 def to_grayscale(img, num_output_channels=1):
-    """Convert image to grayscale version of image.
+    """Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image.

     Args:
-        img (PIL Image): Image to be converted to grayscale.
+        img (PIL Image): PIL Image to be converted to grayscale.
+        num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1.

     Returns:
         PIL Image: Grayscale version of the image.
@@ -963,20 +972,35 @@ def to_grayscale(img, num_output_channels=1):
         if num_output_channels = 3 : returned image is 3 channel with r = g = b
     """
-    if not F_pil._is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
-
-    if num_output_channels == 1:
-        img = img.convert('L')
-    elif num_output_channels == 3:
-        img = img.convert('L')
-        np_img = np.array(img, dtype=np.uint8)
-        np_img = np.dstack([np_img, np_img, np_img])
-        img = Image.fromarray(np_img, 'RGB')
-    else:
-        raise ValueError('num_output_channels should be either 1 or 3')
-
-    return img
+    if isinstance(img, Image.Image):
+        return F_pil.to_grayscale(img, num_output_channels)
+
+    raise TypeError("Input should be PIL Image")
+
+
+def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
+    """Convert RGB image to grayscale version of image.
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
+
+    Note:
+        This method only supports RGB images as input. For inputs in other color spaces,
+        consider using :meth:`~torchvision.transforms.functional.to_grayscale` with a PIL Image.
+
+    Args:
+        img (PIL Image or Tensor): RGB Image to be converted to grayscale.
+        num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1.
+
+    Returns:
+        PIL Image or Tensor: Grayscale version of the image.
+        if num_output_channels = 1 : returned image is single channel
+        if num_output_channels = 3 : returned image is 3 channel with r = g = b
+    """
+    if not isinstance(img, torch.Tensor):
+        return F_pil.to_grayscale(img, num_output_channels)
+
+    return F_t.rgb_to_grayscale(img, num_output_channels)


 def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
...
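
The dispatch above keeps rgb_to_grayscale TorchScript-compatible (the tests in this commit script it directly), while to_grayscale is marked @torch.jit.unused and becomes PIL-only. A minimal sketch of what that enables:

    import torch
    import torchvision.transforms.functional as F

    # Scripting picks the tensor branch of the isinstance dispatch
    scripted = torch.jit.script(F.rgb_to_grayscale)
    out = scripted(torch.rand(3, 8, 8), num_output_channels=1)
    assert out.shape == (1, 8, 8)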
@@ -26,6 +26,13 @@ def _get_image_size(img: Any) -> List[int]:
     raise TypeError("Unexpected type {}".format(type(img)))

+@torch.jit.unused
+def _get_image_num_channels(img: Any) -> int:
+    if _is_pil_image(img):
+        return 1 if img.mode == 'L' else 3
+    raise TypeError("Unexpected type {}".format(type(img)))
+

 @torch.jit.unused
 def hflip(img):
     """Horizontally flip the given PIL Image.
@@ -480,3 +487,33 @@ def perspective(img, perspective_coeffs, interpolation=Image.BICUBIC, fill=None)
     opts = _parse_fill(fill, img, '5.0.0')
     return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs, interpolation, **opts)

+@torch.jit.unused
+def to_grayscale(img, num_output_channels):
+    """Convert PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image.
+
+    Args:
+        img (PIL Image): Image to be converted to grayscale.
+        num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1.
+
+    Returns:
+        PIL Image: Grayscale version of the image.
+        if num_output_channels = 1 : returned image is single channel
+        if num_output_channels = 3 : returned image is 3 channel with r = g = b
+    """
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+    if num_output_channels == 1:
+        img = img.convert('L')
+    elif num_output_channels == 3:
+        img = img.convert('L')
+        np_img = np.array(img, dtype=np.uint8)
+        np_img = np.dstack([np_img, np_img, np_img])
+        img = Image.fromarray(np_img, 'RGB')
+    else:
+        raise ValueError('num_output_channels should be either 1 or 3')
+
+    return img
@@ -18,6 +18,15 @@ def _get_image_size(img: Tensor) -> List[int]:
     raise TypeError("Unexpected type {}".format(type(img)))

+def _get_image_num_channels(img: Tensor) -> int:
+    if img.ndim == 2:
+        return 1
+    elif img.ndim > 2:
+        return img.shape[-3]
+
+    raise TypeError("Input tensor should have at least 2 dimensions. Got {}".format(img.ndim))
+

 def vflip(img: Tensor) -> Tensor:
     """Vertically flip the given Image Tensor.
@@ -67,22 +76,41 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
     return img[..., top:top + height, left:left + width]

-def rgb_to_grayscale(img: Tensor) -> Tensor:
+def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
     """Convert the given RGB Image Tensor to Grayscale.
     For RGB to Grayscale conversion, ITU-R 601-2 luma transform is performed which
     is L = R * 0.2989 + G * 0.5870 + B * 0.1140

     Args:
-        img (Tensor): Image to be converted to Grayscale in the form [C, H, W].
+        img (Tensor): Image to be converted to Grayscale in the form [..., 3, H, W].
+        num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1.

     Returns:
-        Tensor: Grayscale image.
+        Tensor: Grayscale version of the image.
+        if num_output_channels = 1 : returned image is single channel
+        if num_output_channels = 3 : returned image is 3 channel with r = g = b
     """
-    if img.shape[0] != 3:
-        raise TypeError('Input Image does not contain 3 Channels')
+    if img.ndim < 3:
+        raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
+    c = img.shape[-3]
+    if c != 3:
+        raise TypeError("Input image tensor should have 3 channels, but found {}".format(c))
+
+    if num_output_channels not in (1, 3):
+        raise ValueError('num_output_channels should be either 1 or 3')
+
+    r, g, b = img.unbind(dim=-3)
+    # This implementation closely follows the TF one:
+    # https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
+    l_img = (0.2989 * r + 0.587 * g + 0.114 * b).to(img.dtype)
+    l_img = l_img.unsqueeze(dim=-3)
+
+    if num_output_channels == 3:
+        return l_img.expand(img.shape)

-    return (0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2]).to(img.dtype)
+    return l_img
@@ -373,8 +401,8 @@ def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = False)

 def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
-    bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255
-    return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)
+    bound = 1.0 if img1.is_floating_point() else 255.0
+    return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)

 def _rgb2hsv(img):
...
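
Because the tensor kernel above unbinds the channel dimension at -3, it also handles batched inputs of shape [..., 3, H, W]. A minimal sketch (assumes the internal module path torchvision.transforms.functional_tensor, the same one this PR's tests exercise):

    import torch
    import torchvision.transforms.functional_tensor as F_t

    batch = torch.rand(4, 3, 16, 16)
    gray1 = F_t.rgb_to_grayscale(batch, num_output_channels=1)  # (4, 1, 16, 16)
    gray3 = F_t.rgb_to_grayscale(batch, num_output_channels=3)  # (4, 3, 16, 16)

    # Equivalent to the explicit ITU-R 601-2 weighted sum
    r, g, b = batch.unbind(dim=-3)
    expected = (0.2989 * r + 0.587 * g + 0.114 * b).unsqueeze(dim=-3)
    assert torch.allclose(gray1, expected)
    assert gray3.equal(gray1.expand(batch.shape))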
@@ -1354,8 +1354,11 @@ class RandomAffine(torch.nn.Module):
         return s.format(name=self.__class__.__name__, **d)

-class Grayscale(object):
+class Grayscale(torch.nn.Module):
     """Convert image to grayscale.
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading
+    dimensions.

     Args:
         num_output_channels (int): (1 or 3) number of channels desired for output image
@@ -1368,30 +1371,34 @@ class Grayscale(object):
     """

     def __init__(self, num_output_channels=1):
+        super().__init__()
         self.num_output_channels = num_output_channels

-    def __call__(self, img):
+    def forward(self, img: Tensor) -> Tensor:
         """
         Args:
-            img (PIL Image): Image to be converted to grayscale.
+            img (PIL Image or Tensor): Image to be converted to grayscale.

         Returns:
-            PIL Image: Randomly grayscaled image.
+            PIL Image or Tensor: Grayscaled image.
         """
-        return F.to_grayscale(img, num_output_channels=self.num_output_channels)
+        return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels)

     def __repr__(self):
         return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels)

-class RandomGrayscale(object):
+class RandomGrayscale(torch.nn.Module):
     """Randomly convert image to grayscale with a probability of p (default 0.1).
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading
+    dimensions.

     Args:
         p (float): probability that image should be converted to grayscale.

     Returns:
-        PIL Image: Grayscale version of the input image with probability p and unchanged
+        PIL Image or Tensor: Grayscale version of the input image with probability p and unchanged
         with probability (1-p).
         - If input image is 1 channel: grayscale version is 1 channel
         - If input image is 3 channel: grayscale version is 3 channel with r == g == b
@@ -1399,19 +1406,20 @@ class RandomGrayscale(object):
     """

     def __init__(self, p=0.1):
+        super().__init__()
         self.p = p

-    def __call__(self, img):
+    def forward(self, img: Tensor) -> Tensor:
         """
         Args:
-            img (PIL Image): Image to be converted to grayscale.
+            img (PIL Image or Tensor): Image to be converted to grayscale.

         Returns:
-            PIL Image: Randomly grayscaled image.
+            PIL Image or Tensor: Randomly grayscaled image.
         """
-        num_output_channels = 1 if img.mode == 'L' else 3
-        if random.random() < self.p:
-            return F.to_grayscale(img, num_output_channels=num_output_channels)
+        num_output_channels = F._get_image_num_channels(img)
+        if torch.rand(1) < self.p:
+            return F.rgb_to_grayscale(img, num_output_channels=num_output_channels)
         return img

     def __repr__(self):
...
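
With Grayscale and RandomGrayscale now subclassing torch.nn.Module, they can be composed and scripted like the other tensor transforms. A sketch of the intended usage (not part of the commit):

    import torch
    import torchvision.transforms as T

    transform = torch.nn.Sequential(T.RandomGrayscale(p=0.5))
    scripted = torch.jit.script(transform)  # forward() only uses scriptable ops

    img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
    out = scripted(img)  # original image, or its 3-channel grayscale version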