Unverified Commit 857c0303 authored by Philip Meier, committed by GitHub

replace `query_image` with `query_image_dimensions` (#6459)

* add dispatcher for erase image kernels

* simplify RandomErasing

* replace query_image with query_image_dimensions

* use value sentinel and fix test for RandomErasing

* image_dimensions -> chw
parent cff5435f
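The net effect at every call site is mechanical: the two-step lookup through `query_image` plus `get_image_dimensions` collapses into a single `query_chw` call. A minimal before/after sketch of the pattern (the `sample` variable is illustrative):

# Before this commit: find an image in the sample, then ask for its dimensions
image = query_image(sample)
_, height, width = get_image_dimensions(image)

# After this commit: one query that also checks that all images in the sample agree
_, height, width = query_chw(sample)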
@@ -972,51 +972,28 @@ class TestRandomErasing:
         assert 0 <= i <= image.image_size[0] - h
         assert 0 <= j <= image.image_size[1] - w
 
-    @pytest.mark.parametrize("p", [0.0, 1.0])
-    @pytest.mark.parametrize(
-        "inpt_type",
-        [
-            (torch.Tensor, {"shape": (3, 24, 32)}),
-            (PIL.Image.Image, {"size": (24, 32), "mode": "RGB"}),
-        ],
-    )
-    def test__transform(self, p, inpt_type, mocker):
-        value = 1.0
-        transform = transforms.RandomErasing(p=p, value=value)
-
-        inpt = mocker.MagicMock(spec=inpt_type[0], **inpt_type[1])
-        erase_image_tensor_inpt = inpt
-        fn = mocker.patch(
-            "torchvision.prototype.transforms.functional.erase_image_tensor",
-            return_value=mocker.MagicMock(spec=torch.Tensor),
-        )
-
-        if inpt_type[0] == PIL.Image.Image:
-            erase_image_tensor_inpt = mocker.MagicMock(spec=torch.Tensor)
-
-            # vfdev-5: I do not know how to patch pil_to_tensor if it is already imported
-            # TODO: patch pil_to_tensor and run below checks for PIL.Image.Image inputs
-            if p > 0.0:
-                return
-
-            mocker.patch(
-                "torchvision.transforms.functional.pil_to_tensor",
-                return_value=erase_image_tensor_inpt,
-            )
-            mocker.patch(
-                "torchvision.transforms.functional.to_pil_image",
-                return_value=mocker.MagicMock(spec=PIL.Image.Image),
-            )
-
-        # Let's mock transform._get_params to control the output:
-        transform._get_params = mocker.MagicMock()
-
-        output = transform(inpt)
-        print(inpt_type)
-        assert isinstance(output, inpt_type[0])
-
-        params = transform._get_params(inpt)
-        if p > 0.0:
-            fn.assert_called_once_with(erase_image_tensor_inpt, **params)
-        else:
-            assert fn.call_count == 0
+    def test__transform(self, mocker):
+        transform = transforms.RandomErasing()
+        transform._transformed_types = (mocker.MagicMock,)
+
+        i_sentinel = mocker.MagicMock()
+        j_sentinel = mocker.MagicMock()
+        h_sentinel = mocker.MagicMock()
+        w_sentinel = mocker.MagicMock()
+        v_sentinel = mocker.MagicMock()
+        mocker.patch(
+            "torchvision.prototype.transforms._augment.RandomErasing._get_params",
+            return_value=dict(i=i_sentinel, j=j_sentinel, h=h_sentinel, w=w_sentinel, v=v_sentinel),
+        )
+
+        inpt_sentinel = mocker.MagicMock()
+
+        mock = mocker.patch("torchvision.prototype.transforms._augment.F.erase")
+        transform(inpt_sentinel)
+
+        mock.assert_called_once_with(
+            inpt_sentinel, i=i_sentinel, j=j_sentinel, h=h_sentinel, w=w_sentinel, v=v_sentinel
+        )
 
 
 class TestTransform:
...
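The rewritten test exercises only the dispatch logic: `_get_params` and `F.erase` are patched out, and opaque sentinels are asserted to flow through by identity. A standalone sketch of that pytest-mock sentinel pattern, using a hypothetical module `mylib`:

# mylib.py (hypothetical)
def kernel(x):
    return x

def dispatch(x):
    return kernel(x)

# test_mylib.py (requires pytest-mock)
def test_dispatch(mocker):
    x_sentinel = mocker.MagicMock()           # opaque stand-in for any input
    mock = mocker.patch("mylib.kernel")       # intercept the collaborator
    import mylib

    mylib.dispatch(x_sentinel)
    mock.assert_called_once_with(x_sentinel)  # checks identity, not equality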
@@ -9,7 +9,7 @@ from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F
 
 from ._transform import _RandomApplyTransform
-from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_image
+from ._utils import has_any, is_simple_tensor, query_chw
 
 
 class RandomErasing(_RandomApplyTransform):
@@ -38,8 +38,7 @@ class RandomErasing(_RandomApplyTransform):
         self.value = value
 
     def _get_params(self, sample: Any) -> Dict[str, Any]:
-        image = query_image(sample)
-        img_c, img_h, img_w = get_image_dimensions(image)
+        img_c, img_h, img_w = query_chw(sample)
 
         if isinstance(self.value, (int, float)):
             value = [self.value]
@@ -81,20 +80,15 @@ class RandomErasing(_RandomApplyTransform):
                 j = torch.randint(0, img_w - w + 1, size=(1,)).item()
                 break
         else:
-            i, j, h, w, v = 0, 0, img_h, img_w, image
+            i, j, h, w, v = 0, 0, img_h, img_w, None
 
         return dict(i=i, j=j, h=h, w=w, v=v)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if is_simple_tensor(inpt) or isinstance(inpt, features.Image):
-            output = F.erase_image_tensor(inpt, **params)
-            if isinstance(inpt, features.Image):
-                return features.Image.new_like(inpt, output)
-            return output
-        elif isinstance(inpt, PIL.Image.Image):
-            return F.erase_image_pil(inpt, **params)
-        else:
-            return inpt
+        if params["v"] is not None:
+            inpt = F.erase(inpt, **params)
+
+        return inpt
 
 
 class _BaseMixupCutmix(_RandomApplyTransform):
@@ -145,8 +139,7 @@ class RandomCutmix(_BaseMixupCutmix):
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         lam = float(self._dist.sample(()))
 
-        image = query_image(sample)
-        _, H, W = get_image_dimensions(image)
+        _, H, W = query_chw(sample)
 
         r_x = torch.randint(W, ())
         r_y = torch.randint(H, ())
...
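Note the value sentinel from the commit message: when the rejection-sampling loop finds no suitable region, `_get_params` now returns `v=None` instead of smuggling the queried image through the params dict, and `_transform` becomes a type-agnostic no-op in that case. A hedged usage sketch (shapes and arguments are illustrative):

import torch
from torchvision.prototype import transforms

transform = transforms.RandomErasing(p=1.0, value=0.0)
image = torch.rand(3, 24, 32)  # plain CHW tensor
output = transform(image)      # dispatches to F.erase whenever params["v"] is not None
assert output.shape == image.shape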
@@ -10,7 +10,7 @@ from torchvision.prototype.transforms import functional as F, Transform
 from torchvision.transforms.autoaugment import AutoAugmentPolicy
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
 
-from ._utils import get_image_dimensions, is_simple_tensor
+from ._utils import get_chw, is_simple_tensor
 
 K = TypeVar("K")
 V = TypeVar("V")
@@ -281,7 +281,7 @@ class AutoAugment(_AutoAugmentBase):
         sample = inputs if len(inputs) > 1 else inputs[0]
 
         id, image = self._extract_image(sample)
-        num_channels, height, width = get_image_dimensions(image)
+        num_channels, height, width = get_chw(image)
         fill = self._parse_fill(image, num_channels)
 
         policy = self._policies[int(torch.randint(len(self._policies), ()))]
@@ -354,7 +354,7 @@ class RandAugment(_AutoAugmentBase):
         sample = inputs if len(inputs) > 1 else inputs[0]
 
         id, image = self._extract_image(sample)
-        num_channels, height, width = get_image_dimensions(image)
+        num_channels, height, width = get_chw(image)
         fill = self._parse_fill(image, num_channels)
 
         for _ in range(self.num_ops):
@@ -412,7 +412,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
         sample = inputs if len(inputs) > 1 else inputs[0]
 
         id, image = self._extract_image(sample)
-        num_channels, height, width = get_image_dimensions(image)
+        num_channels, height, width = get_chw(image)
         fill = self._parse_fill(image, num_channels)
 
         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
@@ -481,7 +481,7 @@ class AugMix(_AutoAugmentBase):
     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
         id, orig_image = self._extract_image(sample)
-        num_channels, height, width = get_image_dimensions(orig_image)
+        num_channels, height, width = get_chw(orig_image)
         fill = self._parse_fill(orig_image, num_channels)
 
         if isinstance(orig_image, torch.Tensor):
...
@@ -8,7 +8,7 @@ from torchvision.prototype.transforms import functional as F, Transform
 from torchvision.transforms import functional as _F
 
 from ._transform import _RandomApplyTransform
-from ._utils import get_image_dimensions, is_simple_tensor, query_image
+from ._utils import is_simple_tensor, query_chw
 
 T = TypeVar("T", features.Image, torch.Tensor, PIL.Image.Image)
@@ -101,8 +101,7 @@ class RandomPhotometricDistort(Transform):
         self.p = p
 
     def _get_params(self, sample: Any) -> Dict[str, Any]:
-        image = query_image(sample)
-        num_channels, _, _ = get_image_dimensions(image)
+        num_channels, _, _ = query_chw(sample)
         return dict(
             zip(
                 ["brightness", "contrast1", "saturation", "hue", "contrast2"],
...
@@ -15,7 +15,7 @@ from torchvision.transforms.transforms import _check_sequence_input, _setup_angle
 from typing_extensions import Literal
 
 from ._transform import _RandomApplyTransform
-from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_bounding_box, query_image
+from ._utils import has_all, has_any, is_simple_tensor, query_bounding_box, query_chw
 
 
 class RandomHorizontalFlip(_RandomApplyTransform):
@@ -92,8 +92,7 @@ class RandomResizedCrop(Transform):
         # vfdev-5: techically, this op can work on bboxes/segm masks only inputs without image in samples
         # What if we have multiple images/bboxes/masks of different sizes ?
         # TODO: let's support bbox or mask in samples without image
-        image = query_image(sample)
-        _, height, width = get_image_dimensions(image)
+        _, height, width = query_chw(sample)
         area = height * width
 
         log_ratio = torch.log(torch.tensor(self.ratio))
@@ -269,8 +268,7 @@ class RandomZoomOut(_RandomApplyTransform):
             raise ValueError(f"Invalid canvas side range provided {side_range}.")
 
     def _get_params(self, sample: Any) -> Dict[str, Any]:
-        image = query_image(sample)
-        orig_c, orig_h, orig_w = get_image_dimensions(image)
+        orig_c, orig_h, orig_w = query_chw(sample)
 
         r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
         canvas_width = int(orig_w * r)
@@ -373,8 +371,7 @@ class RandomAffine(Transform):
 
         # Get image size
         # TODO: make it work with bboxes and segm masks
-        image = query_image(sample)
-        _, height, width = get_image_dimensions(image)
+        _, height, width = query_chw(sample)
 
         angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
         if self.translate is not None:
@@ -435,8 +432,7 @@ class RandomCrop(Transform):
         self.padding_mode = padding_mode
 
     def _get_params(self, sample: Any) -> Dict[str, Any]:
-        image = query_image(sample)
-        _, height, width = get_image_dimensions(image)
+        _, height, width = query_chw(sample)
 
         if self.padding is not None:
             # update height, width with static padding data
@@ -516,8 +512,7 @@ class RandomPerspective(_RandomApplyTransform):
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         # Get image size
         # TODO: make it work with bboxes and segm masks
-        image = query_image(sample)
-        _, height, width = get_image_dimensions(image)
+        _, height, width = query_chw(sample)
 
         distortion_scale = self.distortion_scale
@@ -589,8 +584,7 @@ class ElasticTransform(Transform):
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         # Get image size
         # TODO: make it work with bboxes and segm masks
-        image = query_image(sample)
-        _, *size = get_image_dimensions(image)
+        _, *size = query_chw(sample)
 
         dx = torch.rand([1, 1] + size) * 2 - 1
         if self.sigma[0] > 0.0:
@@ -643,9 +637,7 @@ class RandomIoUCrop(Transform):
         self.trials = trials
 
     def _get_params(self, sample: Any) -> Dict[str, Any]:
-
-        image = query_image(sample)
-        _, orig_h, orig_w = get_image_dimensions(image)
+        _, orig_h, orig_w = query_chw(sample)
         bboxes = query_bounding_box(sample)
 
         while True:
@@ -743,8 +735,7 @@ class ScaleJitter(Transform):
         self.interpolation = interpolation
 
     def _get_params(self, sample: Any) -> Dict[str, Any]:
-        image = query_image(sample)
-        _, orig_height, orig_width = get_image_dimensions(image)
+        _, orig_height, orig_width = query_chw(sample)
 
         r = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
         new_width = int(self.target_size[1] * r)
@@ -769,8 +760,7 @@ class RandomShortestSize(Transform):
         self.interpolation = interpolation
 
     def _get_params(self, sample: Any) -> Dict[str, Any]:
-        image = query_image(sample)
-        _, orig_height, orig_width = get_image_dimensions(image)
+        _, orig_height, orig_width = query_chw(sample)
 
         min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
         r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))
@@ -799,8 +789,7 @@ class FixedSizeCrop(Transform):
         self.padding_mode = padding_mode
 
     def _get_params(self, sample: Any) -> Dict[str, Any]:
-        image = query_image(sample)
-        _, height, width = get_image_dimensions(image)
+        _, height, width = query_chw(sample)
 
         new_height = min(height, self.crop_height)
         new_width = min(width, self.crop_width)
...
@@ -3,20 +3,12 @@ from typing import Any, Callable, Tuple, Type, Union
 import PIL.Image
 import torch
 from torch.utils._pytree import tree_flatten
+from torchvision._utils import sequence_to_str
 from torchvision.prototype import features
 
 from .functional._meta import get_dimensions_image_pil, get_dimensions_image_tensor
 
 
-def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
-    flat_sample, _ = tree_flatten(sample)
-    for i in flat_sample:
-        if type(i) == torch.Tensor or isinstance(i, (PIL.Image.Image, features.Image)):
-            return i
-
-    raise TypeError("No image was found in the sample")
-
-
 def query_bounding_box(sample: Any) -> features.BoundingBox:
     flat_sample, _ = tree_flatten(sample)
     for i in flat_sample:
@@ -26,7 +18,7 @@ def query_bounding_box(sample: Any) -> features.BoundingBox:
     raise TypeError("No bounding box was found in the sample")
 
 
-def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
+def get_chw(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
     if isinstance(image, features.Image):
         channels = image.num_channels
         height, width = image.image_size
@@ -39,6 +31,20 @@ def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
     return channels, height, width
 
 
+def query_chw(sample: Any) -> Tuple[int, int, int]:
+    flat_sample, _ = tree_flatten(sample)
+    chws = {
+        get_chw(item)
+        for item in flat_sample
+        if isinstance(item, (features.Image, PIL.Image.Image)) or is_simple_tensor(item)
+    }
+    if not chws:
+        raise TypeError("No image was found in the sample")
+    elif len(chws) > 1:
+        raise TypeError(f"Found multiple image dimensions in the sample: {sequence_to_str(sorted(chws))}")
+    return chws.pop()
+
+
 def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
     flat_sample, _ = tree_flatten(sample)
     for type_or_check in types_or_checks:
...
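A sketch of the new helper's contract; the sample structure below is made up, and the private `_utils` import path is an assumption:

import torch
from torchvision.prototype.transforms._utils import query_chw

sample = {"image": torch.rand(3, 16, 16), "label": 7}  # hypothetical sample pytree
assert query_chw(sample) == (3, 16, 16)

# Images with mismatched dimensions should raise, since their CHW tuples differ:
# query_chw([torch.rand(3, 16, 16), torch.rand(3, 8, 8)])  # -> TypeError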
@@ -7,7 +7,7 @@ from ._meta import (
     convert_color_space,
 )  # usort: skip
 
-from ._augment import erase_image_pil, erase_image_tensor
+from ._augment import erase, erase_image_pil, erase_image_tensor
 from ._color import (
     adjust_brightness,
     adjust_brightness_image_pil,
...
@@ -1,4 +1,7 @@
+from typing import Any
+
 import PIL.Image
 import torch
+from torchvision.prototype import features
 from torchvision.transforms import functional_tensor as _FT
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
@@ -14,3 +17,13 @@ def erase_image_pil(
     t_img = pil_to_tensor(img)
     output = erase_image_tensor(t_img, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
     return to_pil_image(output, mode=img.mode)
+
+
+def erase(inpt: Any, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False) -> Any:
+    if isinstance(inpt, torch.Tensor):
+        output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
+        if isinstance(inpt, features.Image):
+            output = features.Image.new_like(inpt, output)
+        return output
+    else:  # isinstance(inpt, PIL.Image.Image):
+        return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
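A hedged sketch of calling the new dispatcher on both supported input kinds (the coordinates and fill value are illustrative):

import PIL.Image
import torch
from torchvision.prototype.transforms import functional as F

v = torch.zeros(3, 4, 4)  # replacement patch for the erased region
tensor_out = F.erase(torch.rand(3, 10, 10), i=2, j=3, h=4, w=4, v=v)         # tensor kernel
pil_out = F.erase(PIL.Image.new("RGB", (10, 10)), i=2, j=3, h=4, w=4, v=v)   # PIL kernel

Since `features.Image` subclasses `torch.Tensor`, the tensor branch also covers it and rewraps the result via `features.Image.new_like`, preserving the feature metadata.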