Unverified commit 095437aa authored by Vasilis Vryniotis, committed by GitHub

Replace get_image_size/num_channels with get_dimensions (#5487)

* Replace get_image_size/num_channels with get_image_dims

* Reduce verbosity

* Fix JIT-scriptability

* Refactoring

* More refactoring

* Replace all _FP/_FT direct calls.

* Remove usages of get_image_size and get_image_num_channels from code-base.

* Fix JIT issues

* Adding missing assertion.
parent f40c8df0
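
In short, the two width-first getters are folded into a single channels-first accessor. A minimal before/after sketch of a call site (illustrative only, using a plain CHW tensor; both APIs exist at this commit):

```python
import torch
from torchvision.transforms import functional as F

img = torch.rand(3, 32, 64)  # CHW tensor: 3 channels, height 32, width 64

# Old API: two separate calls, with width-first ordering for the size.
width, height = F.get_image_size(img)     # [64, 32]
channels = F.get_image_num_channels(img)  # 3

# New API: one call, channels-first [C, H, W] ordering.
channels, height, width = F.get_dimensions(img)  # 3, 32, 64
```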
@@ -270,6 +270,7 @@ you can use a functional transform to build transform classes with custom behavi
     erase
     five_crop
     gaussian_blur
+    get_dimensions
     get_image_num_channels
     get_image_size
     hflip
......
@@ -141,7 +141,7 @@ class RandomCutmix(torch.nn.Module):
         # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
         lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
-        W, H = F.get_image_size(batch)
+        _, H, W = F.get_dimensions(batch)
         r_x = torch.randint(W, (1,))
         r_y = torch.randint(H, (1,))
......
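
Note that `F.get_dimensions` reads the trailing `(C, H, W)` dimensions of a tensor, so the `RandomCutmix` call above also works on its `(N, C, H, W)` batch; a quick sketch of that assumption:

```python
import torch
from torchvision.transforms import functional as F

batch = torch.rand(16, 3, 224, 224)  # (N, C, H, W)
_, H, W = F.get_dimensions(batch)    # channels=3, H=224, W=224; batch dim ignored
```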
@@ -34,7 +34,7 @@ class RandomHorizontalFlip(T.RandomHorizontalFlip):
         if torch.rand(1) < self.p:
             image = F.hflip(image)
             if target is not None:
-                width, _ = F.get_image_size(image)
+                _, _, width = F.get_dimensions(image)
                 target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
                 if "masks" in target:
                     target["masks"] = target["masks"].flip(-1)
@@ -107,7 +107,7 @@ class RandomIoUCrop(nn.Module):
         elif image.ndimension() == 2:
             image = image.unsqueeze(0)
-        orig_w, orig_h = F.get_image_size(image)
+        _, orig_h, orig_w = F.get_dimensions(image)
         while True:
             # sample an option
@@ -192,7 +192,7 @@ class RandomZoomOut(nn.Module):
         if torch.rand(1) >= self.p:
             return image, target
-        orig_w, orig_h = F.get_image_size(image)
+        _, orig_h, orig_w = F.get_dimensions(image)
         r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
         canvas_width = int(orig_w * r)
@@ -270,7 +270,7 @@ class RandomPhotometricDistort(nn.Module):
             image = self._contrast(image)
         if r[6] < self.p:
-            channels = F.get_image_num_channels(image)
+            channels, _, _ = F.get_dimensions(image)
             permutation = torch.randperm(channels)
             is_pil = F._is_pil_image(image)
@@ -317,7 +317,7 @@ class ScaleJitter(nn.Module):
         elif image.ndimension() == 2:
             image = image.unsqueeze(0)
-        orig_width, orig_height = F.get_image_size(image)
+        _, orig_height, orig_width = F.get_dimensions(image)
         r = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
         new_width = int(self.target_size[1] * r)
......
@@ -29,7 +29,7 @@ NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINE
 @pytest.mark.parametrize("device", cpu_and_gpu())
-@pytest.mark.parametrize("fn", [F.get_image_size, F.get_image_num_channels])
+@pytest.mark.parametrize("fn", [F.get_image_size, F.get_image_num_channels, F.get_dimensions])
 def test_image_sizes(device, fn):
     script_F = torch.jit.script(fn)
@@ -1020,7 +1020,9 @@ def test_resized_crop(device, mode):
 @pytest.mark.parametrize(
     "func, args",
     [
+        (F_t.get_dimensions, ()),
         (F_t.get_image_size, ()),
         (F_t.get_image_num_channels, ()),
         (F_t.vflip, ()),
         (F_t.hflip, ()),
         (F_t.crop, (1, 2, 4, 5)),
......
@@ -8,7 +8,7 @@ from ._augment import RandomErasing, RandomMixup, RandomCutmix
 from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment
 from ._container import Compose, RandomApply, RandomChoice, RandomOrder
 from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop
-from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
+from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
 from ._misc import Identity, Normalize, ToDtype, Lambda
 from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
 from ._type_conversion import DecodeImage, LabelToOneHot
@@ -7,7 +7,7 @@ import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, functional as F
-from ._utils import query_image
+from ._utils import query_image, get_image_dimensions

 class RandomErasing(Transform):
@@ -41,8 +41,7 @@ class RandomErasing(Transform):
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         image = query_image(sample)
-        img_c = F.get_image_num_channels(image)
-        img_w, img_h = F.get_image_size(image)
+        img_c, img_h, img_w = get_image_dimensions(image)
         if isinstance(self.value, (int, float)):
             value = [self.value]
@@ -138,7 +137,7 @@ class RandomCutmix(Transform):
         lam = float(self._dist.sample(()))
         image = query_image(sample)
-        W, H = F.get_image_size(image)
+        _, H, W = get_image_dimensions(image)
         r_x = torch.randint(W, ())
         r_y = torch.randint(H, ())
......
@@ -7,7 +7,7 @@ from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F
 from torchvision.prototype.utils._internal import apply_recursively
-from ._utils import query_image
+from ._utils import query_image, get_image_dimensions

 K = TypeVar("K")
 V = TypeVar("V")
@@ -47,7 +47,7 @@ class _AutoAugmentBase(Transform):
             return input
         image = query_image(sample)
-        num_channels = F.get_image_num_channels(image)
+        num_channels, *_ = get_image_dimensions(image)
         fill = self.fill
         if isinstance(fill, (int, float)):
@@ -160,8 +160,8 @@ class AutoAugment(_AutoAugmentBase):
     _AUGMENTATION_SPACE = {
         "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
         "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
-        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
         "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
         "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
         "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
@@ -278,7 +278,7 @@ class AutoAugment(_AutoAugmentBase):
         sample = inputs if len(inputs) > 1 else inputs[0]
         image = query_image(sample)
-        image_size = F.get_image_size(image)
+        _, height, width = get_image_dimensions(image)
         policy = self._policies[int(torch.randint(len(self._policies), ()))]
@@ -288,7 +288,7 @@ class AutoAugment(_AutoAugmentBase):
             magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]
-            magnitudes = magnitudes_fn(10, image_size)
+            magnitudes = magnitudes_fn(10, (height, width))
             if magnitudes is not None:
                 magnitude = float(magnitudes[magnitude_idx])
                 if signed and torch.rand(()) <= 0.5:
@@ -306,8 +306,8 @@ class RandAugment(_AutoAugmentBase):
         "Identity": (lambda num_bins, image_size: None, False),
         "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
         "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
-        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
         "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
         "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
         "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
@@ -334,12 +334,12 @@ class RandAugment(_AutoAugmentBase):
         sample = inputs if len(inputs) > 1 else inputs[0]
         image = query_image(sample)
-        image_size = F.get_image_size(image)
+        _, height, width = get_image_dimensions(image)
         for _ in range(self.num_ops):
             transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
-            magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size)
+            magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width))
             if magnitudes is not None:
                 magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
                 if signed and torch.rand(()) <= 0.5:
@@ -383,11 +383,11 @@ class TrivialAugmentWide(_AutoAugmentBase):
         sample = inputs if len(inputs) > 1 else inputs[0]
         image = query_image(sample)
-        image_size = F.get_image_size(image)
+        _, height, width = get_image_dimensions(image)
         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
-        magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size)
+        magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width))
         if magnitudes is not None:
             magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
             if signed and torch.rand(()) <= 0.5:
......
@@ -8,7 +8,7 @@ from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
 from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
-from ._utils import query_image
+from ._utils import query_image, get_image_dimensions

 class HorizontalFlip(Transform):
@@ -109,7 +109,7 @@ class RandomResizedCrop(Transform):
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         image = query_image(sample)
-        width, height = F.get_image_size(image)
+        _, height, width = get_image_dimensions(image)
         area = height * width
         log_ratio = torch.log(torch.tensor(self.ratio))
......
-from typing import Any, Optional, Union
+from typing import Any, Optional, Tuple, Union

 import PIL.Image
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.utils._internal import query_recursively

+from .functional._meta import get_dimensions_image_tensor, get_dimensions_image_pil

 def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
     def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Image]]:
......
@@ -17,3 +19,16 @@ def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Im
         return next(query_recursively(fn, sample))
     except StopIteration:
         raise TypeError("No image was found in the sample")
+
+
+def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
+    if isinstance(image, features.Image):
+        channels = image.num_channels
+        height, width = image.image_size
+    elif isinstance(image, torch.Tensor):
+        channels, height, width = get_dimensions_image_tensor(image)
+    elif isinstance(image, PIL.Image.Image):
+        channels, height, width = get_dimensions_image_pil(image)
+    else:
+        raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}")
+    return channels, height, width
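
A quick usage sketch for the helper added above (prototype API, import path subject to change; the assertions assume the prototype modules are importable):

```python
import PIL.Image
import torch

from torchvision.prototype.transforms._utils import get_image_dimensions

tensor_img = torch.rand(3, 240, 320)        # (C, H, W)
pil_img = PIL.Image.new("RGB", (320, 240))  # PIL size is (width, height)

# Both input types normalize to the same (channels, height, width) tuple.
assert get_image_dimensions(tensor_img) == (3, 240, 320)
assert get_image_dimensions(pil_img) == (3, 240, 320)
```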
 from torchvision.transforms import InterpolationMode  # usort: skip
-from ._utils import get_image_size, get_image_num_channels  # usort: skip
-from ._meta_conversion import (
+from ._meta import (
     convert_bounding_box_format,
     convert_image_color_space_tensor,
     convert_image_color_space_pil,
......
@@ -5,11 +5,10 @@ import PIL.Image
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import InterpolationMode
-from torchvision.prototype.transforms.functional import get_image_size
 from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
 from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix
-from ._meta_conversion import convert_bounding_box_format
+from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil

 horizontal_flip_image_tensor = _FT.hflip
@@ -40,8 +39,7 @@ def resize_image_tensor(
     antialias: Optional[bool] = None,
 ) -> torch.Tensor:
     new_height, new_width = size
-    old_width, old_height = _FT.get_image_size(image)
-    num_channels = _FT.get_image_num_channels(image)
+    num_channels, old_height, old_width = get_dimensions_image_tensor(image)
     batch_shape = image.shape[:-3]
     return _FT.resize(
         image.reshape((-1, num_channels, old_height, old_width)),
@@ -143,9 +141,9 @@ def affine_image_tensor(
     center_f = [0.0, 0.0]
     if center is not None:
-        width, height = get_image_size(img)
+        _, height, width = get_dimensions_image_tensor(img)
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
-        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))]
+        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
     translate_f = [1.0 * t for t in translate]
     matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
@@ -169,7 +167,7 @@ def affine_image_pil(
     # it is visually better to estimate the center without 0.5 offset
     # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
     if center is None:
-        width, height = get_image_size(img)
+        _, height, width = get_dimensions_image_pil(img)
         center = [width * 0.5, height * 0.5]
     matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
@@ -186,9 +184,9 @@ def rotate_image_tensor(
 ) -> torch.Tensor:
     center_f = [0.0, 0.0]
     if center is not None:
-        width, height = get_image_size(img)
+        _, height, width = get_dimensions_image_tensor(img)
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
-        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))]
+        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
     # due to current incoherence of rotation angle direction between affine and rotate implementations
     # we need to set -angle.
@@ -262,13 +260,13 @@ def _center_crop_compute_crop_anchor(
 def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    image_width, image_height = get_image_size(img)
+    _, image_height, image_width = get_dimensions_image_tensor(img)
     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
         img = pad_image_tensor(img, padding_ltrb, fill=0)
-        image_width, image_height = get_image_size(img)
+        _, image_height, image_width = get_dimensions_image_tensor(img)
         if crop_width == image_width and crop_height == image_height:
             return img
@@ -278,13 +276,13 @@ def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch
 def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    image_width, image_height = get_image_size(img)
+    _, image_height, image_width = get_dimensions_image_pil(img)
     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
         img = pad_image_pil(img, padding_ltrb, fill=0)
-        image_width, image_height = get_image_size(img)
+        _, image_height, image_width = get_dimensions_image_pil(img)
         if crop_width == image_width and crop_height == image_height:
             return img
......
@@ -4,6 +4,10 @@ from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
 from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP

+get_dimensions_image_tensor = _FT.get_dimensions
+get_dimensions_image_pil = _FP.get_dimensions
+
 def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
     xyxy = xywh.clone()
     xyxy[..., 2:] += xyxy[..., :2]
......
-from typing import Tuple, Union, cast
-
-import PIL.Image
-import torch
-from torchvision.prototype import features
-from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
-
-
-def get_image_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int]:
-    if isinstance(image, features.Image):
-        height, width = image.image_size
-        return width, height
-    elif isinstance(image, torch.Tensor):
-        return cast(Tuple[int, int], tuple(_FT.get_image_size(image)))
-    if isinstance(image, PIL.Image.Image):
-        return cast(Tuple[int, int], tuple(_FP.get_image_size(image)))
-    else:
-        raise TypeError(f"unable to get image size from object of type {type(image).__name__}")
-
-
-def get_image_num_channels(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> int:
-    if isinstance(image, features.Image):
-        return image.num_channels
-    elif isinstance(image, torch.Tensor):
-        return _FT.get_image_num_channels(image)
-    if isinstance(image, PIL.Image.Image):
-        return cast(int, _FP.get_image_num_channels(image))
-    else:
-        raise TypeError(f"unable to get num channels from object of type {type(image).__name__}")
@@ -220,13 +220,13 @@ class AutoAugment(torch.nn.Module):
         else:
             raise ValueError(f"The provided policy {policy} is not recognized.")
-    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
+    def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
         return {
             # op_name: (magnitudes, signed)
             "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
             "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
-            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
-            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
             "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
             "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
             "Color": (torch.linspace(0.0, 0.9, num_bins), True),
@@ -260,15 +260,16 @@ class AutoAugment(torch.nn.Module):
             PIL Image or Tensor: AutoAugmented image.
         """
         fill = self.fill
+        channels, height, width = F.get_dimensions(img)
         if isinstance(img, Tensor):
             if isinstance(fill, (int, float)):
-                fill = [float(fill)] * F.get_image_num_channels(img)
+                fill = [float(fill)] * channels
             elif fill is not None:
                 fill = [float(f) for f in fill]
         transform_id, probs, signs = self.get_params(len(self.policies))
-        op_meta = self._augmentation_space(10, F.get_image_size(img))
+        op_meta = self._augmentation_space(10, (height, width))
         for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]):
             if probs[i] <= p:
                 magnitudes, signed = op_meta[op_name]
@@ -317,14 +318,14 @@ class RandAugment(torch.nn.Module):
         self.interpolation = interpolation
         self.fill = fill
-    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
+    def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
         return {
             # op_name: (magnitudes, signed)
             "Identity": (torch.tensor(0.0), False),
             "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
             "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
-            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
-            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
             "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
             "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
             "Color": (torch.linspace(0.0, 0.9, num_bins), True),
@@ -344,13 +345,14 @@ class RandAugment(torch.nn.Module):
             PIL Image or Tensor: Transformed image.
         """
         fill = self.fill
+        channels, height, width = F.get_dimensions(img)
         if isinstance(img, Tensor):
             if isinstance(fill, (int, float)):
-                fill = [float(fill)] * F.get_image_num_channels(img)
+                fill = [float(fill)] * channels
             elif fill is not None:
                 fill = [float(f) for f in fill]
-        op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img))
+        op_meta = self._augmentation_space(self.num_magnitude_bins, (height, width))
         for _ in range(self.num_ops):
             op_index = int(torch.randint(len(op_meta), (1,)).item())
             op_name = list(op_meta.keys())[op_index]
@@ -429,9 +431,10 @@ class TrivialAugmentWide(torch.nn.Module):
             PIL Image or Tensor: Transformed image.
         """
         fill = self.fill
+        channels, height, width = F.get_dimensions(img)
         if isinstance(img, Tensor):
             if isinstance(fill, (int, float)):
-                fill = [float(fill)] * F.get_image_num_channels(img)
+                fill = [float(fill)] * channels
             elif fill is not None:
                 fill = [float(f) for f in fill]
@@ -503,13 +506,13 @@ class AugMix(torch.nn.Module):
         self.interpolation = interpolation
         self.fill = fill
-    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
+    def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
         s = {
             # op_name: (magnitudes, signed)
             "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
             "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
-            "TranslateX": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True),
-            "TranslateY": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True),
+            "TranslateX": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True),
+            "TranslateY": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True),
             "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
             "Posterize": (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
             "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
@@ -547,16 +550,17 @@ class AugMix(torch.nn.Module):
             PIL Image or Tensor: Transformed image.
         """
         fill = self.fill
+        channels, height, width = F.get_dimensions(orig_img)
         if isinstance(orig_img, Tensor):
             img = orig_img
             if isinstance(fill, (int, float)):
-                fill = [float(fill)] * F.get_image_num_channels(img)
+                fill = [float(fill)] * channels
             elif fill is not None:
                 fill = [float(f) for f in fill]
         else:
             img = self._pil_to_tensor(orig_img)
-        op_meta = self._augmentation_space(self._PARAMETER_MAX, F.get_image_size(img))
+        op_meta = self._augmentation_space(self._PARAMETER_MAX, (height, width))
         orig_dims = list(img.shape)
         batch = img.view([1] * max(4 - img.ndim, 0) + orig_dims)
......
@@ -59,6 +59,23 @@ pil_modes_mapping = {
 _is_pil_image = F_pil._is_pil_image

+def get_dimensions(img: Tensor) -> List[int]:
+    """Returns the dimensions of an image as [channels, height, width].
+
+    Args:
+        img (PIL Image or Tensor): The image to be checked.
+
+    Returns:
+        List[int]: The image dimensions.
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(get_dimensions)
+    if isinstance(img, torch.Tensor):
+        return F_t.get_dimensions(img)
+
+    return F_pil.get_dimensions(img)
+
 def get_image_size(img: Tensor) -> List[int]:
     """Returns the size of an image as [width, height].
@@ -512,7 +529,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
     elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
         output_size = (output_size[0], output_size[0])
-    image_width, image_height = get_image_size(img)
+    _, image_height, image_width = get_dimensions(img)
     crop_height, crop_width = output_size
     if crop_width > image_width or crop_height > image_height:
@@ -523,7 +540,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
             (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
         ]
         img = pad(img, padding_ltrb, fill=0)  # PIL uses fill value 0
-        image_width, image_height = get_image_size(img)
+        _, image_height, image_width = get_dimensions(img)
         if crop_width == image_width and crop_height == image_height:
             return img
@@ -721,7 +738,7 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten
     if len(size) != 2:
         raise ValueError("Please provide only two dimensions (h, w) for size.")
-    image_width, image_height = get_image_size(img)
+    _, image_height, image_width = get_dimensions(img)
     crop_height, crop_width = size
     if crop_width > image_width or crop_height > image_height:
         msg = "Requested crop size {} is bigger than input size {}"
@@ -1047,9 +1064,9 @@ def rotate(
     center_f = [0.0, 0.0]
     if center is not None:
-        img_size = get_image_size(img)
+        _, height, width = get_dimensions(img)
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
-        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]
+        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
     # due to current incoherence of rotation angle direction between affine and rotate implementations
     # we need to set -angle.
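
`center` is an `(x, y)` point, so it must be zipped against `[width, height]`; the old code could zip against `img_size` directly only because `get_image_size` returned `[w, h]`. A worked instance of the rewritten line:

```python
center = (100.0, 50.0)  # (x, y) in pixel coordinates
width, height = 320, 240

center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
# x is offset by width/2, y by height/2 -> [-60.0, -70.0]
```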
@@ -1167,22 +1184,22 @@ def affine(
     if center is not None and not isinstance(center, (list, tuple)):
         raise TypeError("Argument center should be a sequence")
-    img_size = get_image_size(img)
+    _, height, width = get_dimensions(img)
     if not isinstance(img, torch.Tensor):
-        # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
+        # center = (width * 0.5 + 0.5, height * 0.5 + 0.5)
         # it is visually better to estimate the center without 0.5 offset
         # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
         if center is None:
-            center = [img_size[0] * 0.5, img_size[1] * 0.5]
+            center = [width * 0.5, height * 0.5]
         matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
         pil_interpolation = pil_modes_mapping[interpolation]
         return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill)
     center_f = [0.0, 0.0]
     if center is not None:
-        img_size = get_image_size(img)
+        _, height, width = get_dimensions(img)
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
-        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]
+        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
     translate_f = [1.0 * t for t in translate]
     matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
......
@@ -20,6 +20,15 @@ def _is_pil_image(img: Any) -> bool:
     return isinstance(img, Image.Image)

+@torch.jit.unused
+def get_dimensions(img: Any) -> List[int]:
+    if _is_pil_image(img):
+        channels = len(img.getbands())
+        width, height = img.size
+        return [channels, height, width]
+    raise TypeError(f"Unexpected type {type(img)}")
+
 @torch.jit.unused
 def get_image_size(img: Any) -> List[int]:
     if _is_pil_image(img):
......
@@ -30,7 +39,7 @@ def get_image_size(img: Any) -> List[int]:
 @torch.jit.unused
 def get_image_num_channels(img: Any) -> int:
     if _is_pil_image(img):
-        return 1 if img.mode == "L" else 3
+        return len(img.getbands())
     raise TypeError(f"Unexpected type {type(img)}")
......
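
The `get_image_num_channels` rewrite also fixes a latent bug: the old `1 if img.mode == "L" else 3` misreported modes such as RGBA or LA, while `len(img.getbands())` counts the actual bands. A small check of that behavior:

```python
import PIL.Image

for mode, expected in [("L", 1), ("RGB", 3), ("RGBA", 4), ("LA", 2)]:
    img = PIL.Image.new(mode, (4, 4))
    assert len(img.getbands()) == expected  # the old code returned 3 for RGBA and LA
```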
@@ -21,6 +21,13 @@ def _assert_threshold(img: Tensor, threshold: float) -> None:
         raise TypeError("Threshold should be less than bound of img.")

+def get_dimensions(img: Tensor) -> List[int]:
+    _assert_image_tensor(img)
+    channels = 1 if img.ndim == 2 else img.shape[-3]
+    height, width = img.shape[-2:]
+    return [channels, height, width]
+
 def get_image_size(img: Tensor) -> List[int]:
     # Returns (w, h) of tensor image
     _assert_image_tensor(img)
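
The tensor implementation above derives everything from the shape: a bare 2-D image counts as one channel, and any leading batch dimensions are ignored because only the trailing dims are consulted. A self-contained sketch of that logic:

```python
import torch

def dims(img: torch.Tensor):
    # same shape arithmetic as get_dimensions above
    channels = 1 if img.ndim == 2 else img.shape[-3]
    height, width = img.shape[-2:]
    return [channels, height, width]

assert dims(torch.rand(32, 64)) == [1, 32, 64]        # 2-D: single channel
assert dims(torch.rand(8, 3, 32, 64)) == [3, 32, 64]  # batch dim ignored
```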
@@ -28,6 +35,7 @@ def get_image_size(img: Tensor) -> List[int]:

 def get_image_num_channels(img: Tensor) -> int:
     _assert_image_tensor(img)
     if img.ndim == 2:
         return 1
     elif img.ndim > 2:
@@ -55,7 +63,7 @@ def _max_value(dtype: torch.dtype) -> float:
 def _assert_channels(img: Tensor, permitted: List[int]) -> None:
-    c = get_image_num_channels(img)
+    c = get_dimensions(img)[0]
     if c not in permitted:
         raise TypeError(f"Input image tensor permitted channel values are {permitted}, but found {c}")
@@ -127,7 +135,7 @@ def hflip(img: Tensor) -> Tensor:
 def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
     _assert_image_tensor(img)
-    w, h = get_image_size(img)
+    _, h, w = get_dimensions(img)
     right = left + width
     bottom = top + height
@@ -175,7 +183,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
     _assert_image_tensor(img)
     _assert_channels(img, [3, 1])
-    c = get_image_num_channels(img)
+    c = get_dimensions(img)[0]
     dtype = img.dtype if torch.is_floating_point(img) else torch.float32
     if c == 3:
         mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
@@ -195,7 +203,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
     _assert_image_tensor(img)
     _assert_channels(img, [1, 3])
-    if get_image_num_channels(img) == 1:  # Match PIL behaviour
+    if get_dimensions(img)[0] == 1:  # Match PIL behaviour
         return img
     orig_dtype = img.dtype
@@ -222,7 +230,7 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
     _assert_channels(img, [1, 3])
-    if get_image_num_channels(img) == 1:  # Match PIL behaviour
+    if get_dimensions(img)[0] == 1:  # Match PIL behaviour
         return img
     return _blend(img, rgb_to_grayscale(img), saturation_factor)
@@ -451,7 +459,7 @@ def resize(
     if antialias and interpolation not in ["bilinear", "bicubic"]:
         raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")
-    w, h = get_image_size(img)
+    _, h, w = get_dimensions(img)
     if isinstance(size, int) or len(size) == 1:  # specified size only for the smallest edge
         short, long = (w, h) if w <= h else (h, w)
@@ -518,7 +526,7 @@ def _assert_grid_transform_inputs(
         warnings.warn("Argument fill should be either int, float, tuple or list")
     # Check fill
-    num_channels = get_image_num_channels(img)
+    num_channels = get_dimensions(img)[0]
     if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
         msg = (
             "The number of elements in 'fill' cannot broadcast to match the number of "
......
@@ -628,7 +628,7 @@ class RandomCrop(torch.nn.Module):
         Returns:
             tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
         """
-        w, h = F.get_image_size(img)
+        _, h, w = F.get_dimensions(img)
         th, tw = output_size
         if h + 1 < th or w + 1 < tw:
@@ -663,7 +663,7 @@ class RandomCrop(torch.nn.Module):
         if self.padding is not None:
             img = F.pad(img, self.padding, self.fill, self.padding_mode)
-        width, height = F.get_image_size(img)
+        _, height, width = F.get_dimensions(img)
         # pad the width if needed
         if self.pad_if_needed and width < self.size[1]:
             padding = [self.size[1] - width, 0]
@@ -793,14 +793,14 @@ class RandomPerspective(torch.nn.Module):
         """
         fill = self.fill
+        channels, height, width = F.get_dimensions(img)
         if isinstance(img, Tensor):
             if isinstance(fill, (int, float)):
-                fill = [float(fill)] * F.get_image_num_channels(img)
+                fill = [float(fill)] * channels
             else:
                 fill = [float(f) for f in fill]
         if torch.rand(1) < self.p:
-            width, height = F.get_image_size(img)
             startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
             return F.perspective(img, startpoints, endpoints, self.interpolation, fill)
         return img
@@ -910,7 +910,7 @@ class RandomResizedCrop(torch.nn.Module):
             tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                 sized crop.
         """
-        width, height = F.get_image_size(img)
+        _, height, width = F.get_dimensions(img)
         area = height * width
         log_ratio = torch.log(torch.tensor(ratio))
@@ -1339,9 +1339,10 @@ class RandomRotation(torch.nn.Module):
             PIL Image or Tensor: Rotated image.
         """
         fill = self.fill
+        channels, _, _ = F.get_dimensions(img)
         if isinstance(img, Tensor):
             if isinstance(fill, (int, float)):
-                fill = [float(fill)] * F.get_image_num_channels(img)
+                fill = [float(fill)] * channels
             else:
                 fill = [float(f) for f in fill]
         angle = self.get_params(self.degrees)
@@ -1519,13 +1520,14 @@ class RandomAffine(torch.nn.Module):
             PIL Image or Tensor: Affine transformed image.
         """
         fill = self.fill
+        channels, height, width = F.get_dimensions(img)
         if isinstance(img, Tensor):
             if isinstance(fill, (int, float)):
-                fill = [float(fill)] * F.get_image_num_channels(img)
+                fill = [float(fill)] * channels
             else:
                 fill = [float(f) for f in fill]
-        img_size = F.get_image_size(img)
+        img_size = [width, height]  # flip for keeping BC on get_params call
         ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
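
`get_params` keeps its legacy `[width, height]` argument, so the forward pass re-flips the channels-first output of `F.get_dimensions` rather than changing the public `get_params` signature; a sketch of the intent:

```python
# e.g. channels, height, width = F.get_dimensions(img)
channels, height, width = 3, 240, 320
img_size = [width, height]  # legacy (W, H) ordering expected by get_params
```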
@@ -1608,7 +1610,7 @@ class RandomGrayscale(torch.nn.Module):
         Returns:
             PIL Image or Tensor: Randomly grayscaled image.
         """
-        num_output_channels = F.get_image_num_channels(img)
+        num_output_channels, _, _ = F.get_dimensions(img)
         if torch.rand(1) < self.p:
             return F.rgb_to_grayscale(img, num_output_channels=num_output_channels)
         return img
......