Unverified Commit 1a300d84 authored by Avijit Dasgupta, committed by GitHub

Cleanup functional_tensor.py (#3159) (#3171)



* added the helper method for dimension checks

* unit tests for dimension check function in functional_tensor

* code formatting and typing

* moved torch image check after tensor check

* unit test cases for test_assert_image_tensor added and refactored

* separate unit test case file deleted

* assert_image_tensor added to the 6 newly created methods

* test cases added for the 6 new methods

* removed wrongly pasted posterize method and added solarize method for testing
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 90645ccd
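The diff below (the test suite followed by torchvision's functional_tensor.py) replaces the repeated "if not _is_tensor_a_torch_image(...)" guards with a single _assert_image_tensor helper and adds a test that exercises it across the tensor ops. A minimal sketch of the resulting call-site behavior, assuming the module import path used by the test suite (tensor shapes chosen only for illustration):

import torch
import torchvision.transforms.functional_tensor as F_t

# A 1-D tensor cannot be interpreted as an image (an image needs at least 2 dims: ..., H, W).
not_an_image = torch.rand(100)

try:
    F_t.vflip(not_an_image)
except TypeError as e:
    print(e)  # Tensor is not a torch image.

# A (C, H, W) tensor passes the new check and is processed as before.
img = torch.rand(3, 16, 16)
flipped = F_t.vflip(img)

Centralizing the guard keeps the exception type and message consistent across the roughly twenty ops touched here, which is exactly what the new test_assert_image_tensor asserts.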
test/test_functional_tensor.py

@@ -13,6 +13,8 @@ from torchvision.transforms import InterpolationMode

from common_utils import TransformsTester

+from typing import Dict, List, Tuple

NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
@@ -34,6 +36,28 @@ class Tester(TransformsTester):
        s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
        self.assertTrue(transformed_batch.allclose(s_transformed_batch))

+    def test_assert_image_tensor(self):
+        shape = (100,)
+        tensor = torch.rand(*shape, dtype=torch.float, device=self.device)
+
+        list_of_methods = [(F_t._get_image_size, (tensor, )), (F_t.vflip, (tensor, )),
+                           (F_t.hflip, (tensor, )), (F_t.crop, (tensor, 1, 2, 4, 5)),
+                           (F_t.adjust_brightness, (tensor, 0.)), (F_t.adjust_contrast, (tensor, 1.)),
+                           (F_t.adjust_hue, (tensor, -0.5)), (F_t.adjust_saturation, (tensor, 2.)),
+                           (F_t.center_crop, (tensor, [10, 11])), (F_t.five_crop, (tensor, [10, 11])),
+                           (F_t.ten_crop, (tensor, [10, 11])), (F_t.pad, (tensor, [2, ], 2, "constant")),
+                           (F_t.resize, (tensor, [10, 11])), (F_t.perspective, (tensor, [0.2, ])),
+                           (F_t.gaussian_blur, (tensor, (2, 2), (0.7, 0.5))),
+                           (F_t.invert, (tensor, )), (F_t.posterize, (tensor, 0)),
+                           (F_t.solarize, (tensor, 0.3)), (F_t.adjust_sharpness, (tensor, 0.3)),
+                           (F_t.autocontrast, (tensor, )), (F_t.equalize, (tensor, ))]
+
+        for func, args in list_of_methods:
+            with self.assertRaises(Exception) as context:
+                func(*args)
+
+            self.assertTrue('Tensor is not a torch image.' in str(context.exception))
+
    def test_vflip(self):
        script_vflip = torch.jit.script(F.vflip)
torchvision/transforms/functional_tensor.py
@@ -11,11 +11,15 @@ def _is_tensor_a_torch_image(x: Tensor) -> bool:
    return x.ndim >= 2


+def _assert_image_tensor(img):
+    if not _is_tensor_a_torch_image(img):
+        raise TypeError("Tensor is not a torch image.")


def _get_image_size(img: Tensor) -> List[int]:
    """Returns (w, h) of tensor image"""
-    if _is_tensor_a_torch_image(img):
-        return [img.shape[-1], img.shape[-2]]
-    raise TypeError("Unexpected input type")
+    _assert_image_tensor(img)
+    return [img.shape[-1], img.shape[-2]]


def _get_image_num_channels(img: Tensor) -> int:
@@ -143,8 +147,7 @@ def vflip(img: Tensor) -> Tensor:
    Returns:
        Tensor: Vertically flipped image Tensor.
    """
-    if not _is_tensor_a_torch_image(img):
-        raise TypeError('tensor is not a torch image.')
+    _assert_image_tensor(img)

    return img.flip(-2)
@@ -163,8 +166,7 @@ def hflip(img: Tensor) -> Tensor:
    Returns:
        Tensor: Horizontally flipped image Tensor.
    """
-    if not _is_tensor_a_torch_image(img):
-        raise TypeError('tensor is not a torch image.')
+    _assert_image_tensor(img)

    return img.flip(-1)
@@ -187,8 +189,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
    Returns:
        Tensor: Cropped image.
    """
-    if not _is_tensor_a_torch_image(img):
-        raise TypeError("tensor is not a torch image.")
+    _assert_image_tensor(img)

    return img[..., top:top + height, left:left + width]
@@ -254,8 +255,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
    if brightness_factor < 0:
        raise ValueError('brightness_factor ({}) is not non-negative.'.format(brightness_factor))

-    if not _is_tensor_a_torch_image(img):
-        raise TypeError('tensor is not a torch image.')
+    _assert_image_tensor(img)

    _assert_channels(img, [1, 3])
@@ -282,8 +282,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
    if contrast_factor < 0:
        raise ValueError('contrast_factor ({}) is not non-negative.'.format(contrast_factor))

-    if not _is_tensor_a_torch_image(img):
-        raise TypeError('tensor is not a torch image.')
+    _assert_image_tensor(img)

    _assert_channels(img, [3])
@@ -326,9 +325,11 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

-    if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
+    if not (isinstance(img, torch.Tensor)):
        raise TypeError('Input img should be Tensor image')

+    _assert_image_tensor(img)

    _assert_channels(img, [3])

    orig_dtype = img.dtype
@@ -367,8 +368,7 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
    if saturation_factor < 0:
        raise ValueError('saturation_factor ({}) is not non-negative.'.format(saturation_factor))

-    if not _is_tensor_a_torch_image(img):
-        raise TypeError('tensor is not a torch image.')
+    _assert_image_tensor(img)

    _assert_channels(img, [3])
@@ -447,8 +447,7 @@ def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
        "Please, use ``F.center_crop`` instead."
    )

-    if not _is_tensor_a_torch_image(img):
-        raise TypeError('tensor is not a torch image.')
+    _assert_image_tensor(img)

    _, image_width, image_height = img.size()
    crop_height, crop_width = output_size
@@ -497,8 +496,7 @@ def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
        "Please, use ``F.five_crop`` instead."
    )

-    if not _is_tensor_a_torch_image(img):
-        raise TypeError('tensor is not a torch image.')
+    _assert_image_tensor(img)

    assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
@@ -553,8 +551,7 @@ def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = Fa
        "Please, use ``F.ten_crop`` instead."
    )

-    if not _is_tensor_a_torch_image(img):
-        raise TypeError('tensor is not a torch image.')
+    _assert_image_tensor(img)

    assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
    first_five = five_crop(img, size)
@@ -703,8 +700,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
    Returns:
        Tensor: Padded image.
    """
-    if not _is_tensor_a_torch_image(img):
-        raise TypeError("tensor is not a torch image.")
+    _assert_image_tensor(img)

    if not isinstance(padding, (int, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
@@ -796,8 +792,7 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Ten
    Returns:
        Tensor: Resized image.
    """
-    if not _is_tensor_a_torch_image(img):
-        raise TypeError("tensor is not a torch image.")
+    _assert_image_tensor(img)

    if not isinstance(size, (int, tuple, list)):
        raise TypeError("Got inappropriate size arg")
@@ -855,8 +850,11 @@ def _assert_grid_transform_inputs(
        supported_interpolation_modes: List[str],
        coeffs: Optional[List[float]] = None,
):
-    if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
-        raise TypeError("Input img should be Tensor Image")
+    if not (isinstance(img, torch.Tensor)):
+        raise TypeError("Input img should be Tensor")
+
+    _assert_image_tensor(img)

    if matrix is not None and not isinstance(matrix, list):
        raise TypeError("Argument matrix should be a list")
@@ -1112,8 +1110,11 @@ def perspective(
    Returns:
        Tensor: transformed image.
    """
-    if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
-        raise TypeError('Input img should be Tensor Image')
+    if not (isinstance(img, torch.Tensor)):
+        raise TypeError('Input img should be Tensor.')
+
+    _assert_image_tensor(img)

    _assert_grid_transform_inputs(
        img,
@@ -1165,8 +1166,11 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
    Returns:
        Tensor: An image that is blurred using gaussian kernel of given parameters
    """
-    if not (isinstance(img, torch.Tensor) or _is_tensor_a_torch_image(img)):
-        raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))
+    if not (isinstance(img, torch.Tensor)):
+        raise TypeError('img should be Tensor. Got {}'.format(type(img)))
+
+    _assert_image_tensor(img)

    dtype = img.dtype if torch.is_floating_point(img) else torch.float32
    kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
@@ -1184,8 +1188,8 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te

def invert(img: Tensor) -> Tensor:
-    if not _is_tensor_a_torch_image(img):
-        raise TypeError('tensor is not a torch image.')
+
+    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
@@ -1197,8 +1201,8 @@ def invert(img: Tensor) -> Tensor:

def posterize(img: Tensor, bits: int) -> Tensor:
-    if not _is_tensor_a_torch_image(img):
-        raise TypeError('tensor is not a torch image.')
+
+    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
@@ -1211,8 +1215,8 @@ def posterize(img: Tensor, bits: int) -> Tensor:

def solarize(img: Tensor, threshold: float) -> Tensor:
-    if not _is_tensor_a_torch_image(img):
-        raise TypeError('tensor is not a torch image.')
+
+    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
@@ -1245,8 +1249,7 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
    if sharpness_factor < 0:
        raise ValueError('sharpness_factor ({}) is not non-negative.'.format(sharpness_factor))

-    if not _is_tensor_a_torch_image(img):
-        raise TypeError('tensor is not a torch image.')
+    _assert_image_tensor(img)

    _assert_channels(img, [1, 3])
@@ -1257,8 +1260,8 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:

def autocontrast(img: Tensor) -> Tensor:
-    if not _is_tensor_a_torch_image(img):
-        raise TypeError('tensor is not a torch image.')
+
+    _assert_image_tensor(img)

    if img.ndim < 3:
        raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
@@ -1297,8 +1300,8 @@ def _equalize_single_image(img: Tensor) -> Tensor:

def equalize(img: Tensor) -> Tensor:
-    if not _is_tensor_a_torch_image(img):
-        raise TypeError('tensor is not a torch image.')
+
+    _assert_image_tensor(img)

    if not (3 <= img.ndim <= 4):
        raise TypeError("Input image tensor should have 3 or 4 dimensions, but found {}".format(img.ndim))
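For reference, the loop in the new test_assert_image_tensor can be expressed equivalently with unittest's subTest and assertRaisesRegex, which name the failing op and pin the exception type. This is only an illustrative sketch of the same check with a subset of the ops, not code from this commit:

import unittest

import torch
import torchvision.transforms.functional_tensor as F_t


class AssertImageTensorSketch(unittest.TestCase):
    def test_non_image_tensor_is_rejected(self):
        # A 1-D tensor is not a valid image, so every checked op must refuse it.
        tensor = torch.rand(100)
        ops = [
            (F_t.vflip, (tensor,)),
            (F_t.hflip, (tensor,)),
            (F_t.adjust_brightness, (tensor, 0.5)),
        ]
        for func, args in ops:
            with self.subTest(op=func.__name__):
                with self.assertRaisesRegex(TypeError, "Tensor is not a torch image"):
                    func(*args)


if __name__ == "__main__":
    unittest.main()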