Unverified commit 095437aa authored by Vasilis Vryniotis, committed by GitHub

Replace get_image_size/num_channels with get_dimensions (#5487)

* Replace get_image_size/num_channels with get_image_dims

* Reduce verbosity

* Fix JIT-scriptability

* Refactoring

* More refactoring

* Replace all _FP/_FT direct calls.

* Remove usages of get_image_size and get_image_num_channels from code-base.

* Fix JIT issues

* Adding missing assertion.
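
For context, the new `get_dimensions` helper returns `[channels, height, width]` in a single call, replacing the separate `get_image_size` (which returned `[width, height]`) and `get_image_num_channels` queries. A minimal usage sketch of the migration (the sample tensor below is only illustrative):

```python
import torch
from torchvision.transforms import functional as F

img = torch.rand(3, 224, 224)  # any CHW image tensor; PIL images work too

# Old API: two calls, width-first ordering for the size
# width, height = F.get_image_size(img)
# channels = F.get_image_num_channels(img)

# New API: one call, [channels, height, width] ordering
channels, height, width = F.get_dimensions(img)
print(channels, height, width)  # 3 224 224
```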
parent f40c8df0
@@ -270,6 +270,7 @@ you can use a functional transform to build transform classes with custom behavi
 erase
 five_crop
 gaussian_blur
+get_dimensions
 get_image_num_channels
 get_image_size
 hflip
...
@@ -141,7 +141,7 @@ class RandomCutmix(torch.nn.Module):
 # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
 lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
-W, H = F.get_image_size(batch)
+_, H, W = F.get_dimensions(batch)
 r_x = torch.randint(W, (1,))
 r_y = torch.randint(H, (1,))
...
@@ -34,7 +34,7 @@ class RandomHorizontalFlip(T.RandomHorizontalFlip):
 if torch.rand(1) < self.p:
 image = F.hflip(image)
 if target is not None:
-width, _ = F.get_image_size(image)
+_, _, width = F.get_dimensions(image)
 target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
 if "masks" in target:
 target["masks"] = target["masks"].flip(-1)
@@ -107,7 +107,7 @@ class RandomIoUCrop(nn.Module):
 elif image.ndimension() == 2:
 image = image.unsqueeze(0)
-orig_w, orig_h = F.get_image_size(image)
+_, orig_h, orig_w = F.get_dimensions(image)
 while True:
 # sample an option
@@ -192,7 +192,7 @@ class RandomZoomOut(nn.Module):
 if torch.rand(1) >= self.p:
 return image, target
-orig_w, orig_h = F.get_image_size(image)
+_, orig_h, orig_w = F.get_dimensions(image)
 r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
 canvas_width = int(orig_w * r)
@@ -270,7 +270,7 @@ class RandomPhotometricDistort(nn.Module):
 image = self._contrast(image)
 if r[6] < self.p:
-channels = F.get_image_num_channels(image)
+channels, _, _ = F.get_dimensions(image)
 permutation = torch.randperm(channels)
 is_pil = F._is_pil_image(image)
@@ -317,7 +317,7 @@ class ScaleJitter(nn.Module):
 elif image.ndimension() == 2:
 image = image.unsqueeze(0)
-orig_width, orig_height = F.get_image_size(image)
+_, orig_height, orig_width = F.get_dimensions(image)
 r = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
 new_width = int(self.target_size[1] * r)
...
@@ -29,7 +29,7 @@ NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINE
 @pytest.mark.parametrize("device", cpu_and_gpu())
-@pytest.mark.parametrize("fn", [F.get_image_size, F.get_image_num_channels])
+@pytest.mark.parametrize("fn", [F.get_image_size, F.get_image_num_channels, F.get_dimensions])
 def test_image_sizes(device, fn):
 script_F = torch.jit.script(fn)
@@ -1020,7 +1020,9 @@ def test_resized_crop(device, mode):
 @pytest.mark.parametrize(
 "func, args",
 [
+(F_t.get_dimensions, ()),
 (F_t.get_image_size, ()),
+(F_t.get_image_num_channels, ()),
 (F_t.vflip, ()),
 (F_t.hflip, ()),
 (F_t.crop, (1, 2, 4, 5)),
...
@@ -8,7 +8,7 @@ from ._augment import RandomErasing, RandomMixup, RandomCutmix
 from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment
 from ._container import Compose, RandomApply, RandomChoice, RandomOrder
 from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop
-from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
+from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
 from ._misc import Identity, Normalize, ToDtype, Lambda
 from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
 from ._type_conversion import DecodeImage, LabelToOneHot
@@ -7,7 +7,7 @@ import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, functional as F
-from ._utils import query_image
+from ._utils import query_image, get_image_dimensions
 class RandomErasing(Transform):
@@ -41,8 +41,7 @@ class RandomErasing(Transform):
 def _get_params(self, sample: Any) -> Dict[str, Any]:
 image = query_image(sample)
-img_c = F.get_image_num_channels(image)
-img_w, img_h = F.get_image_size(image)
+img_c, img_h, img_w = get_image_dimensions(image)
 if isinstance(self.value, (int, float)):
 value = [self.value]
@@ -138,7 +137,7 @@ class RandomCutmix(Transform):
 lam = float(self._dist.sample(()))
 image = query_image(sample)
-W, H = F.get_image_size(image)
+_, H, W = get_image_dimensions(image)
 r_x = torch.randint(W, ())
 r_y = torch.randint(H, ())
...
@@ -7,7 +7,7 @@ from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F
 from torchvision.prototype.utils._internal import apply_recursively
-from ._utils import query_image
+from ._utils import query_image, get_image_dimensions
 K = TypeVar("K")
 V = TypeVar("V")
@@ -47,7 +47,7 @@ class _AutoAugmentBase(Transform):
 return input
 image = query_image(sample)
-num_channels = F.get_image_num_channels(image)
+num_channels, *_ = get_image_dimensions(image)
 fill = self.fill
 if isinstance(fill, (int, float)):
@@ -160,8 +160,8 @@ class AutoAugment(_AutoAugmentBase):
 _AUGMENTATION_SPACE = {
 "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
 "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
-"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
 "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
 "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
 "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
@@ -278,7 +278,7 @@ class AutoAugment(_AutoAugmentBase):
 sample = inputs if len(inputs) > 1 else inputs[0]
 image = query_image(sample)
-image_size = F.get_image_size(image)
+_, height, width = get_image_dimensions(image)
 policy = self._policies[int(torch.randint(len(self._policies), ()))]
@@ -288,7 +288,7 @@ class AutoAugment(_AutoAugmentBase):
 magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]
-magnitudes = magnitudes_fn(10, image_size)
+magnitudes = magnitudes_fn(10, (height, width))
 if magnitudes is not None:
 magnitude = float(magnitudes[magnitude_idx])
 if signed and torch.rand(()) <= 0.5:
@@ -306,8 +306,8 @@ class RandAugment(_AutoAugmentBase):
 "Identity": (lambda num_bins, image_size: None, False),
 "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
 "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
-"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+"TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+"TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
 "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
 "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
 "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
@@ -334,12 +334,12 @@ class RandAugment(_AutoAugmentBase):
 sample = inputs if len(inputs) > 1 else inputs[0]
 image = query_image(sample)
-image_size = F.get_image_size(image)
+_, height, width = get_image_dimensions(image)
 for _ in range(self.num_ops):
 transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
-magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size)
+magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width))
 if magnitudes is not None:
 magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
 if signed and torch.rand(()) <= 0.5:
@@ -383,11 +383,11 @@ class TrivialAugmentWide(_AutoAugmentBase):
 sample = inputs if len(inputs) > 1 else inputs[0]
 image = query_image(sample)
-image_size = F.get_image_size(image)
+_, height, width = get_image_dimensions(image)
 transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
-magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size)
+magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width))
 if magnitudes is not None:
 magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
 if signed and torch.rand(()) <= 0.5:
...
@@ -8,7 +8,7 @@ from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
 from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
-from ._utils import query_image
+from ._utils import query_image, get_image_dimensions
 class HorizontalFlip(Transform):
@@ -109,7 +109,7 @@ class RandomResizedCrop(Transform):
 def _get_params(self, sample: Any) -> Dict[str, Any]:
 image = query_image(sample)
-width, height = F.get_image_size(image)
+_, height, width = get_image_dimensions(image)
 area = height * width
 log_ratio = torch.log(torch.tensor(self.ratio))
...
-from typing import Any, Optional, Union
+from typing import Any, Optional, Tuple, Union
 import PIL.Image
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.utils._internal import query_recursively
+from .functional._meta import get_dimensions_image_tensor, get_dimensions_image_pil
 def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
 def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Image]]:
@@ -17,3 +19,16 @@ def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Im
 return next(query_recursively(fn, sample))
 except StopIteration:
 raise TypeError("No image was found in the sample")
+def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
+    if isinstance(image, features.Image):
+        channels = image.num_channels
+        height, width = image.image_size
+    elif isinstance(image, torch.Tensor):
+        channels, height, width = get_dimensions_image_tensor(image)
+    elif isinstance(image, PIL.Image.Image):
+        channels, height, width = get_dimensions_image_pil(image)
+    else:
+        raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}")
+    return channels, height, width
 from torchvision.transforms import InterpolationMode # usort: skip
-from ._utils import get_image_size, get_image_num_channels # usort: skip
-from ._meta_conversion import (
+from ._meta import (
 convert_bounding_box_format,
 convert_image_color_space_tensor,
 convert_image_color_space_pil,
...
@@ -5,11 +5,10 @@ import PIL.Image
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import InterpolationMode
-from torchvision.prototype.transforms.functional import get_image_size
 from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
 from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix
-from ._meta_conversion import convert_bounding_box_format
+from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil
 horizontal_flip_image_tensor = _FT.hflip
@@ -40,8 +39,7 @@ def resize_image_tensor(
 antialias: Optional[bool] = None,
 ) -> torch.Tensor:
 new_height, new_width = size
-old_width, old_height = _FT.get_image_size(image)
-num_channels = _FT.get_image_num_channels(image)
+num_channels, old_height, old_width = get_dimensions_image_tensor(image)
 batch_shape = image.shape[:-3]
 return _FT.resize(
 image.reshape((-1, num_channels, old_height, old_width)),
@@ -143,9 +141,9 @@ def affine_image_tensor(
 center_f = [0.0, 0.0]
 if center is not None:
-width, height = get_image_size(img)
+_, height, width = get_dimensions_image_tensor(img)
 # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
-center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))]
+center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
 translate_f = [1.0 * t for t in translate]
 matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
@@ -169,7 +167,7 @@ def affine_image_pil(
 # it is visually better to estimate the center without 0.5 offset
 # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
 if center is None:
-width, height = get_image_size(img)
+_, height, width = get_dimensions_image_pil(img)
 center = [width * 0.5, height * 0.5]
 matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
@@ -186,9 +184,9 @@ def rotate_image_tensor(
 ) -> torch.Tensor:
 center_f = [0.0, 0.0]
 if center is not None:
-width, height = get_image_size(img)
+_, height, width = get_dimensions_image_tensor(img)
 # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
-center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))]
+center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
 # due to current incoherence of rotation angle direction between affine and rotate implementations
 # we need to set -angle.
@@ -262,13 +260,13 @@ def _center_crop_compute_crop_anchor(
 def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch.Tensor:
 crop_height, crop_width = _center_crop_parse_output_size(output_size)
-image_width, image_height = get_image_size(img)
+_, image_height, image_width = get_dimensions_image_tensor(img)
 if crop_height > image_height or crop_width > image_width:
 padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
 img = pad_image_tensor(img, padding_ltrb, fill=0)
-image_width, image_height = get_image_size(img)
+_, image_height, image_width = get_dimensions_image_tensor(img)
 if crop_width == image_width and crop_height == image_height:
 return img
@@ -278,13 +276,13 @@ def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch
 def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
 crop_height, crop_width = _center_crop_parse_output_size(output_size)
-image_width, image_height = get_image_size(img)
+_, image_height, image_width = get_dimensions_image_pil(img)
 if crop_height > image_height or crop_width > image_width:
 padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
 img = pad_image_pil(img, padding_ltrb, fill=0)
-image_width, image_height = get_image_size(img)
+_, image_height, image_width = get_dimensions_image_pil(img)
 if crop_width == image_width and crop_height == image_height:
 return img
...
@@ -4,6 +4,10 @@ from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
 from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
+get_dimensions_image_tensor = _FT.get_dimensions
+get_dimensions_image_pil = _FP.get_dimensions
 def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
 xyxy = xywh.clone()
 xyxy[..., 2:] += xyxy[..., :2]
...
-from typing import Tuple, Union, cast
-import PIL.Image
-import torch
-from torchvision.prototype import features
-from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
-def get_image_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int]:
-    if isinstance(image, features.Image):
-        height, width = image.image_size
-        return width, height
-    elif isinstance(image, torch.Tensor):
-        return cast(Tuple[int, int], tuple(_FT.get_image_size(image)))
-    if isinstance(image, PIL.Image.Image):
-        return cast(Tuple[int, int], tuple(_FP.get_image_size(image)))
-    else:
-        raise TypeError(f"unable to get image size from object of type {type(image).__name__}")
-def get_image_num_channels(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> int:
-    if isinstance(image, features.Image):
-        return image.num_channels
-    elif isinstance(image, torch.Tensor):
-        return _FT.get_image_num_channels(image)
-    if isinstance(image, PIL.Image.Image):
-        return cast(int, _FP.get_image_num_channels(image))
-    else:
-        raise TypeError(f"unable to get num channels from object of type {type(image).__name__}")
@@ -220,13 +220,13 @@ class AutoAugment(torch.nn.Module):
 else:
 raise ValueError(f"The provided policy {policy} is not recognized.")
-def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
+def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
 return {
 # op_name: (magnitudes, signed)
 "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
 "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
-"TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
-"TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+"TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+"TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
 "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
 "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
 "Color": (torch.linspace(0.0, 0.9, num_bins), True),
@@ -260,15 +260,16 @@ class AutoAugment(torch.nn.Module):
 PIL Image or Tensor: AutoAugmented image.
 """
 fill = self.fill
+channels, height, width = F.get_dimensions(img)
 if isinstance(img, Tensor):
 if isinstance(fill, (int, float)):
-fill = [float(fill)] * F.get_image_num_channels(img)
+fill = [float(fill)] * channels
 elif fill is not None:
 fill = [float(f) for f in fill]
 transform_id, probs, signs = self.get_params(len(self.policies))
-op_meta = self._augmentation_space(10, F.get_image_size(img))
+op_meta = self._augmentation_space(10, (height, width))
 for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]):
 if probs[i] <= p:
 magnitudes, signed = op_meta[op_name]
@@ -317,14 +318,14 @@ class RandAugment(torch.nn.Module):
 self.interpolation = interpolation
 self.fill = fill
-def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
+def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
 return {
 # op_name: (magnitudes, signed)
 "Identity": (torch.tensor(0.0), False),
 "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
 "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
-"TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
-"TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+"TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+"TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
 "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
 "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
 "Color": (torch.linspace(0.0, 0.9, num_bins), True),
@@ -344,13 +345,14 @@ class RandAugment(torch.nn.Module):
 PIL Image or Tensor: Transformed image.
 """
 fill = self.fill
+channels, height, width = F.get_dimensions(img)
 if isinstance(img, Tensor):
 if isinstance(fill, (int, float)):
-fill = [float(fill)] * F.get_image_num_channels(img)
+fill = [float(fill)] * channels
 elif fill is not None:
 fill = [float(f) for f in fill]
-op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img))
+op_meta = self._augmentation_space(self.num_magnitude_bins, (height, width))
 for _ in range(self.num_ops):
 op_index = int(torch.randint(len(op_meta), (1,)).item())
 op_name = list(op_meta.keys())[op_index]
@@ -429,9 +431,10 @@ class TrivialAugmentWide(torch.nn.Module):
 PIL Image or Tensor: Transformed image.
 """
 fill = self.fill
+channels, height, width = F.get_dimensions(img)
 if isinstance(img, Tensor):
 if isinstance(fill, (int, float)):
-fill = [float(fill)] * F.get_image_num_channels(img)
+fill = [float(fill)] * channels
 elif fill is not None:
 fill = [float(f) for f in fill]
@@ -503,13 +506,13 @@ class AugMix(torch.nn.Module):
 self.interpolation = interpolation
 self.fill = fill
-def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
+def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[Tensor, bool]]:
 s = {
 # op_name: (magnitudes, signed)
 "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
 "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
-"TranslateX": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True),
-"TranslateY": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True),
+"TranslateX": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True),
+"TranslateY": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True),
 "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
 "Posterize": (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
 "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
@@ -547,16 +550,17 @@ class AugMix(torch.nn.Module):
 PIL Image or Tensor: Transformed image.
 """
 fill = self.fill
+channels, height, width = F.get_dimensions(orig_img)
 if isinstance(orig_img, Tensor):
 img = orig_img
 if isinstance(fill, (int, float)):
-fill = [float(fill)] * F.get_image_num_channels(img)
+fill = [float(fill)] * channels
 elif fill is not None:
 fill = [float(f) for f in fill]
 else:
 img = self._pil_to_tensor(orig_img)
-op_meta = self._augmentation_space(self._PARAMETER_MAX, F.get_image_size(img))
+op_meta = self._augmentation_space(self._PARAMETER_MAX, (height, width))
 orig_dims = list(img.shape)
 batch = img.view([1] * max(4 - img.ndim, 0) + orig_dims)
...
@@ -59,6 +59,23 @@ pil_modes_mapping = {
 _is_pil_image = F_pil._is_pil_image
+def get_dimensions(img: Tensor) -> List[int]:
+    """Returns the dimensions of an image as [channels, height, width].
+    Args:
+        img (PIL Image or Tensor): The image to be checked.
+    Returns:
+        List[int]: The image dimensions.
+    """
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(get_dimensions)
+    if isinstance(img, torch.Tensor):
+        return F_t.get_dimensions(img)
+    return F_pil.get_dimensions(img)
 def get_image_size(img: Tensor) -> List[int]:
 """Returns the size of an image as [width, height].
@@ -512,7 +529,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
 elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
 output_size = (output_size[0], output_size[0])
-image_width, image_height = get_image_size(img)
+_, image_height, image_width = get_dimensions(img)
 crop_height, crop_width = output_size
 if crop_width > image_width or crop_height > image_height:
@@ -523,7 +540,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
 (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
 ]
 img = pad(img, padding_ltrb, fill=0) # PIL uses fill value 0
-image_width, image_height = get_image_size(img)
+_, image_height, image_width = get_dimensions(img)
 if crop_width == image_width and crop_height == image_height:
 return img
@@ -721,7 +738,7 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten
 if len(size) != 2:
 raise ValueError("Please provide only two dimensions (h, w) for size.")
-image_width, image_height = get_image_size(img)
+_, image_height, image_width = get_dimensions(img)
 crop_height, crop_width = size
 if crop_width > image_width or crop_height > image_height:
 msg = "Requested crop size {} is bigger than input size {}"
@@ -1047,9 +1064,9 @@ def rotate(
 center_f = [0.0, 0.0]
 if center is not None:
-img_size = get_image_size(img)
+_, height, width = get_dimensions(img)
 # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
-center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]
+center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
 # due to current incoherence of rotation angle direction between affine and rotate implementations
 # we need to set -angle.
@@ -1167,22 +1184,22 @@ def affine(
 if center is not None and not isinstance(center, (list, tuple)):
 raise TypeError("Argument center should be a sequence")
-img_size = get_image_size(img)
+_, height, width = get_dimensions(img)
 if not isinstance(img, torch.Tensor):
-# center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
+# center = (width * 0.5 + 0.5, height * 0.5 + 0.5)
 # it is visually better to estimate the center without 0.5 offset
 # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
 if center is None:
-center = [img_size[0] * 0.5, img_size[1] * 0.5]
+center = [width * 0.5, height * 0.5]
 matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
 pil_interpolation = pil_modes_mapping[interpolation]
 return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill)
 center_f = [0.0, 0.0]
 if center is not None:
-img_size = get_image_size(img)
+_, height, width = get_dimensions(img)
 # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
-center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]
+center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
 translate_f = [1.0 * t for t in translate]
 matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
...
@@ -20,6 +20,15 @@ def _is_pil_image(img: Any) -> bool:
 return isinstance(img, Image.Image)
+@torch.jit.unused
+def get_dimensions(img: Any) -> List[int]:
+    if _is_pil_image(img):
+        channels = len(img.getbands())
+        width, height = img.size
+        return [channels, height, width]
+    raise TypeError(f"Unexpected type {type(img)}")
 @torch.jit.unused
 def get_image_size(img: Any) -> List[int]:
 if _is_pil_image(img):
@@ -30,7 +39,7 @@ def get_image_size(img: Any) -> List[int]:
 @torch.jit.unused
 def get_image_num_channels(img: Any) -> int:
 if _is_pil_image(img):
-return 1 if img.mode == "L" else 3
+return len(img.getbands())
 raise TypeError(f"Unexpected type {type(img)}")
...
@@ -21,6 +21,13 @@ def _assert_threshold(img: Tensor, threshold: float) -> None:
 raise TypeError("Threshold should be less than bound of img.")
+def get_dimensions(img: Tensor) -> List[int]:
+    _assert_image_tensor(img)
+    channels = 1 if img.ndim == 2 else img.shape[-3]
+    height, width = img.shape[-2:]
+    return [channels, height, width]
 def get_image_size(img: Tensor) -> List[int]:
 # Returns (w, h) of tensor image
 _assert_image_tensor(img)
@@ -28,6 +35,7 @@ def get_image_size(img: Tensor) -> List[int]:
 def get_image_num_channels(img: Tensor) -> int:
+_assert_image_tensor(img)
 if img.ndim == 2:
 return 1
 elif img.ndim > 2:
@@ -55,7 +63,7 @@ def _max_value(dtype: torch.dtype) -> float:
 def _assert_channels(img: Tensor, permitted: List[int]) -> None:
-c = get_image_num_channels(img)
+c = get_dimensions(img)[0]
 if c not in permitted:
 raise TypeError(f"Input image tensor permitted channel values are {permitted}, but found {c}")
@@ -127,7 +135,7 @@ def hflip(img: Tensor) -> Tensor:
 def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
 _assert_image_tensor(img)
-w, h = get_image_size(img)
+_, h, w = get_dimensions(img)
 right = left + width
 bottom = top + height
@@ -175,7 +183,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
 _assert_image_tensor(img)
 _assert_channels(img, [3, 1])
-c = get_image_num_channels(img)
+c = get_dimensions(img)[0]
 dtype = img.dtype if torch.is_floating_point(img) else torch.float32
 if c == 3:
 mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)
@@ -195,7 +203,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
 _assert_image_tensor(img)
 _assert_channels(img, [1, 3])
-if get_image_num_channels(img) == 1: # Match PIL behaviour
+if get_dimensions(img)[0] == 1: # Match PIL behaviour
 return img
 orig_dtype = img.dtype
@@ -222,7 +230,7 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
 _assert_channels(img, [1, 3])
-if get_image_num_channels(img) == 1: # Match PIL behaviour
+if get_dimensions(img)[0] == 1: # Match PIL behaviour
 return img
 return _blend(img, rgb_to_grayscale(img), saturation_factor)
@@ -451,7 +459,7 @@ def resize(
 if antialias and interpolation not in ["bilinear", "bicubic"]:
 raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")
-w, h = get_image_size(img)
+_, h, w = get_dimensions(img)
 if isinstance(size, int) or len(size) == 1: # specified size only for the smallest edge
 short, long = (w, h) if w <= h else (h, w)
@@ -518,7 +526,7 @@ def _assert_grid_transform_inputs(
 warnings.warn("Argument fill should be either int, float, tuple or list")
 # Check fill
-num_channels = get_image_num_channels(img)
+num_channels = get_dimensions(img)[0]
 if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
 msg = (
 "The number of elements in 'fill' cannot broadcast to match the number of "
...
@@ -628,7 +628,7 @@ class RandomCrop(torch.nn.Module):
 Returns:
 tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
 """
-w, h = F.get_image_size(img)
+_, h, w = F.get_dimensions(img)
 th, tw = output_size
 if h + 1 < th or w + 1 < tw:
@@ -663,7 +663,7 @@ class RandomCrop(torch.nn.Module):
 if self.padding is not None:
 img = F.pad(img, self.padding, self.fill, self.padding_mode)
-width, height = F.get_image_size(img)
+_, height, width = F.get_dimensions(img)
 # pad the width if needed
 if self.pad_if_needed and width < self.size[1]:
 padding = [self.size[1] - width, 0]
@@ -793,14 +793,14 @@ class RandomPerspective(torch.nn.Module):
 """
 fill = self.fill
+channels, height, width = F.get_dimensions(img)
 if isinstance(img, Tensor):
 if isinstance(fill, (int, float)):
-fill = [float(fill)] * F.get_image_num_channels(img)
+fill = [float(fill)] * channels
 else:
 fill = [float(f) for f in fill]
 if torch.rand(1) < self.p:
-width, height = F.get_image_size(img)
 startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
 return F.perspective(img, startpoints, endpoints, self.interpolation, fill)
 return img
@@ -910,7 +910,7 @@ class RandomResizedCrop(torch.nn.Module):
 tuple: params (i, j, h, w) to be passed to ``crop`` for a random
 sized crop.
 """
-width, height = F.get_image_size(img)
+_, height, width = F.get_dimensions(img)
 area = height * width
 log_ratio = torch.log(torch.tensor(ratio))
@@ -1339,9 +1339,10 @@ class RandomRotation(torch.nn.Module):
 PIL Image or Tensor: Rotated image.
 """
 fill = self.fill
+channels, _, _ = F.get_dimensions(img)
 if isinstance(img, Tensor):
 if isinstance(fill, (int, float)):
-fill = [float(fill)] * F.get_image_num_channels(img)
+fill = [float(fill)] * channels
 else:
 fill = [float(f) for f in fill]
 angle = self.get_params(self.degrees)
@@ -1519,13 +1520,14 @@ class RandomAffine(torch.nn.Module):
 PIL Image or Tensor: Affine transformed image.
 """
 fill = self.fill
+channels, height, width = F.get_dimensions(img)
 if isinstance(img, Tensor):
 if isinstance(fill, (int, float)):
-fill = [float(fill)] * F.get_image_num_channels(img)
+fill = [float(fill)] * channels
 else:
 fill = [float(f) for f in fill]
-img_size = F.get_image_size(img)
+img_size = [width, height] # flip for keeping BC on get_params call
 ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
@@ -1608,7 +1610,7 @@ class RandomGrayscale(torch.nn.Module):
 Returns:
 PIL Image or Tensor: Randomly grayscaled image.
 """
-num_output_channels = F.get_image_num_channels(img)
+num_output_channels, _, _ = F.get_dimensions(img)
 if torch.rand(1) < self.p:
 return F.rgb_to_grayscale(img, num_output_channels=num_output_channels)
 return img
...