Unverified Commit 96f6e0a1 authored by Vasilis Vryniotis, committed by GitHub

Make get_image_size and get_image_num_channels public. (#4321)

parent 37a9ee5b
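
For reference, a minimal usage sketch of the newly public helpers (an illustration only, assuming a torchvision build that includes this commit; the example tensor is arbitrary):

import torch
from torchvision.transforms import functional as F

img = torch.rand(3, 16, 18)           # arbitrary CHW tensor image; a PIL Image works too
print(F.get_image_size(img))          # [18, 16], i.e. [width, height]
print(F.get_image_num_channels(img))  # 3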
@@ -33,7 +33,7 @@ class RandomHorizontalFlip(T.RandomHorizontalFlip):
         if torch.rand(1) < self.p:
             image = F.hflip(image)
             if target is not None:
-                width, _ = F._get_image_size(image)
+                width, _ = F.get_image_size(image)
                 target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
                 if "masks" in target:
                     target["masks"] = target["masks"].flip(-1)
@@ -76,7 +76,7 @@ class RandomIoUCrop(nn.Module):
         elif image.ndimension() == 2:
             image = image.unsqueeze(0)

-        orig_w, orig_h = F._get_image_size(image)
+        orig_w, orig_h = F.get_image_size(image)

         while True:
             # sample an option
@@ -157,7 +157,7 @@ class RandomZoomOut(nn.Module):
         if torch.rand(1) < self.p:
             return image, target

-        orig_w, orig_h = F._get_image_size(image)
+        orig_w, orig_h = F.get_image_size(image)

         r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
         canvas_width = int(orig_w * r)
@@ -226,7 +226,7 @@ class RandomPhotometricDistort(nn.Module):
             image = self._contrast(image)

         if r[6] < self.p:
-            channels = F._get_image_num_channels(image)
+            channels = F.get_image_num_channels(image)
             permutation = torch.randperm(channels)

             is_pil = F._is_pil_image(image)
...
@@ -31,6 +31,24 @@ from typing import Dict, List, Sequence, Tuple
 NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC


+@pytest.mark.parametrize('device', cpu_and_gpu())
+@pytest.mark.parametrize('fn', [F.get_image_size, F.get_image_num_channels])
+def test_image_sizes(device, fn):
+    script_F = torch.jit.script(fn)
+
+    img_tensor, pil_img = _create_data(16, 18, 3, device=device)
+    value_img = fn(img_tensor)
+    value_pil_img = fn(pil_img)
+    assert value_img == value_pil_img
+
+    value_img_script = script_F(img_tensor)
+    assert value_img == value_img_script
+
+    batch_tensors = _create_data_batch(16, 18, 3, num_samples=4, device=device)
+    value_img_batch = fn(batch_tensors)
+    assert value_img == value_img_batch
+
+
 @needs_cuda
 def test_scale_channel():
     """Make sure that _scale_channel gives the same results on CPU and GPU as
@@ -908,7 +926,7 @@ def test_resized_crop(device, mode):

 @pytest.mark.parametrize('device', cpu_and_gpu())
 @pytest.mark.parametrize('func, args', [
-    (F_t._get_image_size, ()), (F_t.vflip, ()),
+    (F_t.get_image_size, ()), (F_t.vflip, ()),
     (F_t.hflip, ()), (F_t.crop, (1, 2, 4, 5)),
     (F_t.adjust_brightness, (0., )), (F_t.adjust_contrast, (1., )),
     (F_t.adjust_hue, (-0.5, )), (F_t.adjust_saturation, (2., )),
...
@@ -188,7 +188,7 @@ class AutoAugment(torch.nn.Module):
         fill = self.fill
         if isinstance(img, Tensor):
             if isinstance(fill, (int, float)):
-                fill = [float(fill)] * F._get_image_num_channels(img)
+                fill = [float(fill)] * F.get_image_num_channels(img)
             elif fill is not None:
                 fill = [float(f) for f in fill]

@@ -209,10 +209,10 @@ class AutoAugment(torch.nn.Module):
                     img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)],
                                    interpolation=self.interpolation, fill=fill)
                 elif op_name == "TranslateX":
-                    img = F.affine(img, angle=0.0, translate=[int(F._get_image_size(img)[0] * magnitude), 0], scale=1.0,
+                    img = F.affine(img, angle=0.0, translate=[int(F.get_image_size(img)[0] * magnitude), 0], scale=1.0,
                                    interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill)
                 elif op_name == "TranslateY":
-                    img = F.affine(img, angle=0.0, translate=[0, int(F._get_image_size(img)[1] * magnitude)], scale=1.0,
+                    img = F.affine(img, angle=0.0, translate=[0, int(F.get_image_size(img)[1] * magnitude)], scale=1.0,
                                    interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill)
                 elif op_name == "Rotate":
                     img = F.rotate(img, magnitude, interpolation=self.interpolation, fill=fill)
...
@@ -58,22 +58,34 @@ pil_modes_mapping = {
 _is_pil_image = F_pil._is_pil_image


-def _get_image_size(img: Tensor) -> List[int]:
-    """Returns image size as [w, h]
+def get_image_size(img: Tensor) -> List[int]:
+    """Returns the size of an image as [width, height].
+
+    Args:
+        img (PIL Image or Tensor): The image to be checked.
+
+    Returns:
+        List[int]: The image size.
     """
     if isinstance(img, torch.Tensor):
-        return F_t._get_image_size(img)
+        return F_t.get_image_size(img)

-    return F_pil._get_image_size(img)
+    return F_pil.get_image_size(img)


-def _get_image_num_channels(img: Tensor) -> int:
-    """Returns number of image channels
+def get_image_num_channels(img: Tensor) -> int:
+    """Returns the number of channels of an image.
+
+    Args:
+        img (PIL Image or Tensor): The image to be checked.
+
+    Returns:
+        int: The number of channels.
     """
     if isinstance(img, torch.Tensor):
-        return F_t._get_image_num_channels(img)
+        return F_t.get_image_num_channels(img)

-    return F_pil._get_image_num_channels(img)
+    return F_pil.get_image_num_channels(img)


 @torch.jit.unused
@@ -500,7 +512,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
     elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
         output_size = (output_size[0], output_size[0])

-    image_width, image_height = _get_image_size(img)
+    image_width, image_height = get_image_size(img)
     crop_height, crop_width = output_size

     if crop_width > image_width or crop_height > image_height:
@@ -511,7 +523,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
             (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
         ]
         img = pad(img, padding_ltrb, fill=0)  # PIL uses fill value 0
-        image_width, image_height = _get_image_size(img)
+        image_width, image_height = get_image_size(img)
         if crop_width == image_width and crop_height == image_height:
             return img

@@ -696,7 +708,7 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten
     if len(size) != 2:
         raise ValueError("Please provide only two dimensions (h, w) for size.")

-    image_width, image_height = _get_image_size(img)
+    image_width, image_height = get_image_size(img)
     crop_height, crop_width = size
     if crop_width > image_width or crop_height > image_height:
         msg = "Requested crop size {} is bigger than input size {}"
@@ -993,7 +1005,7 @@ def rotate(

     center_f = [0.0, 0.0]
     if center is not None:
-        img_size = _get_image_size(img)
+        img_size = get_image_size(img)
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
         center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]

@@ -1094,7 +1106,7 @@ def affine(
     if len(shear) != 2:
         raise ValueError("Shear should be a sequence containing two values. Got {}".format(shear))

-    img_size = _get_image_size(img)
+    img_size = get_image_size(img)
     if not isinstance(img, torch.Tensor):
         # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
         # it is visually better to estimate the center without 0.5 offset
...
@@ -20,14 +20,14 @@ def _is_pil_image(img: Any) -> bool:


 @torch.jit.unused
-def _get_image_size(img: Any) -> List[int]:
+def get_image_size(img: Any) -> List[int]:
     if _is_pil_image(img):
-        return img.size
+        return list(img.size)
     raise TypeError("Unexpected type {}".format(type(img)))


 @torch.jit.unused
-def _get_image_num_channels(img: Any) -> int:
+def get_image_num_channels(img: Any) -> int:
     if _is_pil_image(img):
         return 1 if img.mode == 'L' else 3
     raise TypeError("Unexpected type {}".format(type(img)))
...
@@ -16,13 +16,13 @@ def _assert_image_tensor(img: Tensor) -> None:
         raise TypeError("Tensor is not a torch image.")


-def _get_image_size(img: Tensor) -> List[int]:
+def get_image_size(img: Tensor) -> List[int]:
     # Returns (w, h) of tensor image
     _assert_image_tensor(img)
     return [img.shape[-1], img.shape[-2]]


-def _get_image_num_channels(img: Tensor) -> int:
+def get_image_num_channels(img: Tensor) -> int:
     if img.ndim == 2:
         return 1
     elif img.ndim > 2:
@@ -50,7 +50,7 @@ def _max_value(dtype: torch.dtype) -> float:


 def _assert_channels(img: Tensor, permitted: List[int]) -> None:
-    c = _get_image_num_channels(img)
+    c = get_image_num_channels(img)
     if c not in permitted:
         raise TypeError("Input image tensor permitted channel values are {}, but found {}".format(permitted, c))

@@ -122,7 +122,7 @@ def hflip(img: Tensor) -> Tensor:

 def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
     _assert_image_tensor(img)

-    w, h = _get_image_size(img)
+    w, h = get_image_size(img)
     right = left + width
     bottom = top + height
@@ -187,7 +187,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
     _assert_image_tensor(img)

     _assert_channels(img, [1, 3])
-    if _get_image_num_channels(img) == 1:  # Match PIL behaviour
+    if get_image_num_channels(img) == 1:  # Match PIL behaviour
         return img

     orig_dtype = img.dtype
@@ -513,7 +513,7 @@ def resize(
     if antialias and interpolation not in ["bilinear", "bicubic"]:
         raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")

-    w, h = _get_image_size(img)
+    w, h = get_image_size(img)

     if isinstance(size, int) or len(size) == 1:  # specified size only for the smallest edge
         short, long = (w, h) if w <= h else (h, w)
@@ -586,7 +586,7 @@ def _assert_grid_transform_inputs(
             warnings.warn("Argument fill should be either int, float, tuple or list")

     # Check fill
-    num_channels = _get_image_num_channels(img)
+    num_channels = get_image_num_channels(img)
     if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
         msg = ("The number of elements in 'fill' cannot broadcast to match the number of "
                "channels of the image ({} != {})")
...
@@ -575,7 +575,7 @@ class RandomCrop(torch.nn.Module):
         Returns:
             tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
         """
-        w, h = F._get_image_size(img)
+        w, h = F.get_image_size(img)
         th, tw = output_size

         if h + 1 < th or w + 1 < tw:
@@ -613,7 +613,7 @@ class RandomCrop(torch.nn.Module):
         if self.padding is not None:
             img = F.pad(img, self.padding, self.fill, self.padding_mode)

-        width, height = F._get_image_size(img)
+        width, height = F.get_image_size(img)
         # pad the width if needed
         if self.pad_if_needed and width < self.size[1]:
             padding = [self.size[1] - width, 0]
@@ -742,12 +742,12 @@ class RandomPerspective(torch.nn.Module):
         fill = self.fill
         if isinstance(img, Tensor):
             if isinstance(fill, (int, float)):
-                fill = [float(fill)] * F._get_image_num_channels(img)
+                fill = [float(fill)] * F.get_image_num_channels(img)
             else:
                 fill = [float(f) for f in fill]

         if torch.rand(1) < self.p:
-            width, height = F._get_image_size(img)
+            width, height = F.get_image_size(img)
             startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
             return F.perspective(img, startpoints, endpoints, self.interpolation, fill)
         return img
@@ -858,7 +858,7 @@ class RandomResizedCrop(torch.nn.Module):
             tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                 sized crop.
         """
-        width, height = F._get_image_size(img)
+        width, height = F.get_image_size(img)
         area = height * width

         log_ratio = torch.log(torch.tensor(ratio))
@@ -1280,7 +1280,7 @@ class RandomRotation(torch.nn.Module):
         fill = self.fill
         if isinstance(img, Tensor):
             if isinstance(fill, (int, float)):
-                fill = [float(fill)] * F._get_image_num_channels(img)
+                fill = [float(fill)] * F.get_image_num_channels(img)
             else:
                 fill = [float(f) for f in fill]
         angle = self.get_params(self.degrees)
@@ -1439,11 +1439,11 @@ class RandomAffine(torch.nn.Module):
         fill = self.fill
         if isinstance(img, Tensor):
             if isinstance(fill, (int, float)):
-                fill = [float(fill)] * F._get_image_num_channels(img)
+                fill = [float(fill)] * F.get_image_num_channels(img)
             else:
                 fill = [float(f) for f in fill]

-        img_size = F._get_image_size(img)
+        img_size = F.get_image_size(img)

         ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
@@ -1529,7 +1529,7 @@ class RandomGrayscale(torch.nn.Module):
         Returns:
             PIL Image or Tensor: Randomly grayscaled image.
         """
-        num_output_channels = F._get_image_num_channels(img)
+        num_output_channels = F.get_image_num_channels(img)
         if torch.rand(1) < self.p:
             return F.rgb_to_grayscale(img, num_output_channels=num_output_channels)
         return img
...