Unverified commit e3238e5a authored by Philip Meier, committed by GitHub

only flatten a pytree once (#6767)

parent dc5fd831
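
The diff below touches many call sites, but the underlying change is a single pattern: instead of `_check_inputs`, `_get_params`, and the `query_*`/`has_*` helpers each calling `tree_flatten` on the nested sample, `Transform.forward` now flattens once and hands the resulting flat list around. A minimal sketch of that after-state, simplified from the base class in this diff and assuming a trivial identity `_transform` so the snippet runs on its own (the real base class raises `NotImplementedError` there and also filters by `_transformed_types`):

```python
from typing import Any, Dict, List

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten


class SketchTransform(torch.nn.Module):
    # Helpers now receive the already-flattened inputs instead of the nested sample.
    def _check_inputs(self, flat_inputs: List[Any]) -> None:
        pass

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        return dict()

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return inpt  # identity here; concrete transforms override this

    def forward(self, *inputs: Any) -> Any:
        # Flatten the (possibly nested) sample exactly once ...
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
        # ... and reuse the flat list for checks, parameter sampling, and transforming.
        self._check_inputs(flat_inputs)
        params = self._get_params(flat_inputs)
        flat_outputs = [self._transform(inpt, params) for inpt in flat_inputs]
        return tree_unflatten(flat_outputs, spec)


sample = {"image": torch.rand(3, 24, 32), "label": 7}
assert SketchTransform()(sample)["image"].shape == (3, 24, 32)
```

The `_RandomApplyTransform.forward` in the diff follows the same shape, with the random-apply check placed between flattening and `_get_params` so the flat list is still built only once.
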
......@@ -437,7 +437,7 @@ class TestRandomZoomOut:
image = mocker.MagicMock(spec=features.Image)
h, w = image.spatial_size = (24, 32)
params = transform._get_params(image)
params = transform._get_params([image])
assert len(params["padding"]) == 4
assert 0 <= params["padding"][0] <= (side_range[1] - 1) * w
......@@ -462,7 +462,7 @@ class TestRandomZoomOut:
_ = transform(inpt)
torch.manual_seed(12)
torch.rand(1) # random apply changes random state
params = transform._get_params(inpt)
params = transform._get_params([inpt])
fill = transforms.functional._geometry._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill)
......@@ -623,7 +623,7 @@ class TestRandomAffine:
h, w = image.spatial_size
transform = transforms.RandomAffine(degrees, translate=translate, scale=scale, shear=shear)
params = transform._get_params(image)
params = transform._get_params([image])
if not isinstance(degrees, (list, tuple)):
assert -degrees <= params["angle"] <= degrees
......@@ -690,7 +690,7 @@ class TestRandomAffine:
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
params = transform._get_params(inpt)
params = transform._get_params([inpt])
fill = transforms.functional._geometry._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, interpolation=interpolation, fill=fill, center=center)
......@@ -722,7 +722,7 @@ class TestRandomCrop:
h, w = image.spatial_size
transform = transforms.RandomCrop(size, padding=padding, pad_if_needed=pad_if_needed)
params = transform._get_params(image)
params = transform._get_params([image])
if padding is not None:
if isinstance(padding, int):
......@@ -793,7 +793,7 @@ class TestRandomCrop:
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
params = transform._get_params(inpt)
params = transform._get_params([inpt])
if padding is None and not pad_if_needed:
fn_crop.assert_called_once_with(
inpt, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1]
......@@ -832,7 +832,7 @@ class TestGaussianBlur:
@pytest.mark.parametrize("sigma", [10.0, [10.0, 12.0]])
def test__get_params(self, sigma):
transform = transforms.GaussianBlur(3, sigma=sigma)
params = transform._get_params(None)
params = transform._get_params([])
if isinstance(sigma, float):
assert params["sigma"][0] == params["sigma"][1] == 10
......@@ -867,7 +867,7 @@ class TestGaussianBlur:
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
params = transform._get_params(inpt)
params = transform._get_params([inpt])
fn.assert_called_once_with(inpt, kernel_size, **params)
......@@ -912,7 +912,7 @@ class TestRandomPerspective:
image.num_channels = 3
image.spatial_size = (24, 32)
params = transform._get_params(image)
params = transform._get_params([image])
h, w = image.spatial_size
assert "perspective_coeffs" in params
......@@ -935,7 +935,7 @@ class TestRandomPerspective:
_ = transform(inpt)
torch.manual_seed(12)
torch.rand(1) # random apply changes random state
params = transform._get_params(inpt)
params = transform._get_params([inpt])
fill = transforms.functional._geometry._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
......@@ -973,7 +973,7 @@ class TestElasticTransform:
image.num_channels = 3
image.spatial_size = (24, 32)
params = transform._get_params(image)
params = transform._get_params([image])
h, w = image.spatial_size
displacement = params["displacement"]
......@@ -1006,7 +1006,7 @@ class TestElasticTransform:
# Let's mock transform._get_params to control the output:
transform._get_params = mocker.MagicMock()
_ = transform(inpt)
params = transform._get_params(inpt)
params = transform._get_params([inpt])
fill = transforms.functional._geometry._convert_fill_arg(fill)
fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
......@@ -1035,7 +1035,7 @@ class TestRandomErasing:
transform = transforms.RandomErasing(value=[1, 2, 3, 4])
with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value"):
transform._get_params(image)
transform._get_params([image])
@pytest.mark.parametrize("value", [5.0, [1, 2, 3], "random"])
def test__get_params(self, value, mocker):
......@@ -1044,7 +1044,7 @@ class TestRandomErasing:
image.spatial_size = (24, 32)
transform = transforms.RandomErasing(value=value)
params = transform._get_params(image)
params = transform._get_params([image])
v = params["v"]
h, w = params["h"], params["w"]
......@@ -1197,6 +1197,7 @@ class TestContainers:
[
[transforms.Pad(2), transforms.RandomCrop(28)],
[lambda x: 2.0 * x, transforms.Pad(2), transforms.RandomCrop(28)],
[transforms.Pad(2), lambda x: 2.0 * x, transforms.RandomCrop(28)],
],
)
def test_ctor(self, transform_cls, trfms):
......@@ -1339,7 +1340,7 @@ class TestScaleJitter:
n_samples = 5
for _ in range(n_samples):
params = transform._get_params(sample)
params = transform._get_params([sample])
assert "size" in params
size = params["size"]
......@@ -1386,7 +1387,7 @@ class TestRandomShortestSize:
transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size)
sample = mocker.MagicMock(spec=features.Image, num_channels=3, spatial_size=spatial_size)
params = transform._get_params(sample)
params = transform._get_params([sample])
assert "size" in params
size = params["size"]
......@@ -1554,13 +1555,13 @@ class TestFixedSizeCrop:
transform = transforms.FixedSizeCrop(size=crop_size)
sample = dict(
image=make_image(size=spatial_size, color_space=features.ColorSpace.RGB),
bounding_boxes=make_bounding_box(
flat_inputs = [
make_image(size=spatial_size, color_space=features.ColorSpace.RGB),
make_bounding_box(
format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=batch_shape
),
)
params = transform._get_params(sample)
]
params = transform._get_params(flat_inputs)
assert params["needs_crop"]
assert params["height"] <= crop_size[0]
......@@ -1759,7 +1760,7 @@ class TestRandomResize:
transform = transforms.RandomResize(min_size=min_size, max_size=max_size)
for _ in range(10):
params = transform._get_params(None)
params = transform._get_params([])
assert isinstance(params["size"], list) and len(params["size"]) == 1
size = params["size"][0]
......
......@@ -639,7 +639,7 @@ class TestContainerTransforms:
prototype_transform = prototype_transforms.RandomApply(
[
prototype_transforms.Resize(256),
legacy_transforms.CenterCrop(224),
prototype_transforms.CenterCrop(224),
],
p=p,
)
......
......@@ -45,8 +45,8 @@ class RandomErasing(_RandomApplyTransform):
self._log_ratio = torch.log(torch.tensor(self.ratio))
def _get_params(self, sample: Any) -> Dict[str, Any]:
img_c, img_h, img_w = query_chw(sample)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
img_c, img_h, img_w = query_chw(flat_inputs)
if isinstance(self.value, (int, float)):
value = [self.value]
......@@ -107,13 +107,13 @@ class _BaseMixupCutmix(_RandomApplyTransform):
self.alpha = alpha
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
def _check_inputs(self, sample: Any) -> None:
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if not (
has_any(sample, features.Image, features.Video, features.is_simple_tensor)
and has_any(sample, features.OneHotLabel)
has_any(flat_inputs, features.Image, features.Video, features.is_simple_tensor)
and has_any(flat_inputs, features.OneHotLabel)
):
raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.")
if has_any(sample, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label):
if has_any(flat_inputs, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label):
raise TypeError(
f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
)
......@@ -127,7 +127,7 @@ class _BaseMixupCutmix(_RandomApplyTransform):
class RandomMixup(_BaseMixupCutmix):
def _get_params(self, sample: Any) -> Dict[str, Any]:
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict(lam=float(self._dist.sample(())))
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
......@@ -150,10 +150,10 @@ class RandomMixup(_BaseMixupCutmix):
class RandomCutmix(_BaseMixupCutmix):
def _get_params(self, sample: Any) -> Dict[str, Any]:
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
lam = float(self._dist.sample(()))
H, W = query_spatial_size(sample)
H, W = query_spatial_size(flat_inputs)
r_x = torch.randint(W, ())
r_y = torch.randint(H, ())
......@@ -344,9 +344,9 @@ class SimpleCopyPaste(_RandomApplyTransform):
c3 += 1
def forward(self, *inputs: Any) -> Any:
flat_sample, spec = tree_flatten(inputs)
flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
images, targets = self._extract_image_targets(flat_sample)
images, targets = self._extract_image_targets(flat_inputs)
# images = [t1, t2, ..., tN]
# Let's define paste_images as shifted list of input images
......@@ -384,6 +384,6 @@ class SimpleCopyPaste(_RandomApplyTransform):
output_targets.append(output_target)
# Insert updated images and targets into input flat_sample
self._insert_outputs(flat_sample, output_images, output_targets)
self._insert_outputs(flat_inputs, output_images, output_targets)
return tree_unflatten(flat_sample, spec)
return tree_unflatten(flat_inputs, spec)
......@@ -4,7 +4,7 @@ from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, TypeV
import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
from torchvision.prototype import features
from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._meta import get_spatial_size
......@@ -31,16 +31,17 @@ class _AutoAugmentBase(Transform):
key = keys[int(torch.randint(len(keys), ()))]
return key, dct[key]
def _extract_image_or_video(
def _flatten_and_extract_image_or_video(
self,
sample: Any,
inputs: Any,
unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.Mask),
) -> Tuple[int, Union[features.ImageType, features.VideoType]]:
sample_flat, _ = tree_flatten(sample)
) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[features.ImageType, features.VideoType]]:
flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
image_or_videos = []
for id, inpt in enumerate(sample_flat):
for idx, inpt in enumerate(flat_inputs):
if _isinstance(inpt, (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)):
image_or_videos.append((id, inpt))
image_or_videos.append((idx, inpt))
elif isinstance(inpt, unsupported_types):
raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
......@@ -51,12 +52,18 @@ class _AutoAugmentBase(Transform):
f"Auto augment transformations are only properly defined for a single image or video, "
f"but found {len(image_or_videos)}."
)
return image_or_videos[0]
def _put_into_sample(self, sample: Any, id: int, item: Any) -> Any:
sample_flat, spec = tree_flatten(sample)
sample_flat[id] = item
return tree_unflatten(sample_flat, spec)
idx, image_or_video = image_or_videos[0]
return (flat_inputs, spec, idx), image_or_video
def _unflatten_and_insert_image_or_video(
self,
flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
image_or_video: Union[features.ImageType, features.VideoType],
) -> Any:
flat_inputs, spec, idx = flat_inputs_with_spec
flat_inputs[idx] = image_or_video
return tree_unflatten(flat_inputs, spec)
def _apply_image_or_video_transform(
self,
......@@ -275,9 +282,7 @@ class AutoAugment(_AutoAugmentBase):
raise ValueError(f"The provided policy {policy} is not recognized.")
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
id, image_or_video = self._extract_image_or_video(sample)
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_spatial_size(image_or_video)
policy = self._policies[int(torch.randint(len(self._policies), ()))]
......@@ -300,7 +305,7 @@ class AutoAugment(_AutoAugmentBase):
image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
return self._put_into_sample(sample, id, image_or_video)
return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
class RandAugment(_AutoAugmentBase):
......@@ -346,9 +351,7 @@ class RandAugment(_AutoAugmentBase):
self.num_magnitude_bins = num_magnitude_bins
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
id, image_or_video = self._extract_image_or_video(sample)
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_spatial_size(image_or_video)
for _ in range(self.num_ops):
......@@ -364,7 +367,7 @@ class RandAugment(_AutoAugmentBase):
image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
return self._put_into_sample(sample, id, image_or_video)
return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
class TrivialAugmentWide(_AutoAugmentBase):
......@@ -400,9 +403,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
self.num_magnitude_bins = num_magnitude_bins
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
id, image_or_video = self._extract_image_or_video(sample)
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_spatial_size(image_or_video)
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
......@@ -418,7 +419,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
image_or_video = self._apply_image_or_video_transform(
image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
return self._put_into_sample(sample, id, image_or_video)
return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
class AugMix(_AutoAugmentBase):
......@@ -471,8 +472,7 @@ class AugMix(_AutoAugmentBase):
return torch._sample_dirichlet(params)
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
id, orig_image_or_video = self._extract_image_or_video(sample)
flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_spatial_size(orig_image_or_video)
if isinstance(orig_image_or_video, torch.Tensor):
......@@ -525,4 +525,4 @@ class AugMix(_AutoAugmentBase):
elif isinstance(orig_image_or_video, PIL.Image.Image):
mix = F.to_image_pil(mix)
return self._put_into_sample(sample, id, mix)
return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix)
import collections.abc
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import PIL.Image
import torch
......@@ -53,7 +53,7 @@ class ColorJitter(Transform):
def _generate_value(left: float, right: float) -> float:
return float(torch.distributions.Uniform(left, right).sample())
def _get_params(self, sample: Any) -> Dict[str, Any]:
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
fn_idx = torch.randperm(4)
b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1])
......@@ -99,8 +99,8 @@ class RandomPhotometricDistort(Transform):
self.saturation = saturation
self.p = p
def _get_params(self, sample: Any) -> Dict[str, Any]:
num_channels, *_ = query_chw(sample)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
num_channels, *_ = query_chw(flat_inputs)
return dict(
zip(
["brightness", "contrast1", "saturation", "hue", "contrast2"],
......
import warnings
from typing import Any, Dict, Union
from typing import Any, Dict, List, Union
import numpy as np
import PIL.Image
......@@ -79,8 +79,8 @@ class RandomGrayscale(_RandomApplyTransform):
super().__init__(p=p)
def _get_params(self, sample: Any) -> Dict[str, Any]:
num_input_channels, *_ = query_chw(sample)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
num_input_channels, *_ = query_chw(flat_inputs)
return dict(num_input_channels=num_input_channels)
def _transform(
......
......@@ -104,8 +104,8 @@ class RandomResizedCrop(Transform):
self._log_ratio = torch.log(torch.tensor(self.ratio))
def _get_params(self, sample: Any) -> Dict[str, Any]:
height, width = query_spatial_size(sample)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_spatial_size(flat_inputs)
area = height * width
log_ratio = self._log_ratio
......@@ -184,8 +184,8 @@ class FiveCrop(Transform):
) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
return F.five_crop(inpt, self.size)
def _check_inputs(self, sample: Any) -> None:
if has_any(sample, features.BoundingBox, features.Mask):
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if has_any(flat_inputs, features.BoundingBox, features.Mask):
raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
......@@ -201,8 +201,8 @@ class TenCrop(Transform):
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
self.vertical_flip = vertical_flip
def _check_inputs(self, sample: Any) -> None:
if has_any(sample, features.BoundingBox, features.Mask):
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if has_any(flat_inputs, features.BoundingBox, features.Mask):
raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
def _transform(
......@@ -256,8 +256,8 @@ class RandomZoomOut(_RandomApplyTransform):
if side_range[0] < 1.0 or side_range[0] > side_range[1]:
raise ValueError(f"Invalid canvas side range provided {side_range}.")
def _get_params(self, sample: Any) -> Dict[str, Any]:
orig_h, orig_w = query_spatial_size(sample)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_h, orig_w = query_spatial_size(flat_inputs)
r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r)
......@@ -299,7 +299,7 @@ class RandomRotation(Transform):
self.center = center
def _get_params(self, sample: Any) -> Dict[str, Any]:
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
return dict(angle=angle)
......@@ -355,8 +355,8 @@ class RandomAffine(Transform):
self.center = center
def _get_params(self, sample: Any) -> Dict[str, Any]:
height, width = query_spatial_size(sample)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_spatial_size(flat_inputs)
angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
if self.translate is not None:
......@@ -417,8 +417,8 @@ class RandomCrop(Transform):
self.fill = _setup_fill_arg(fill)
self.padding_mode = padding_mode
def _get_params(self, sample: Any) -> Dict[str, Any]:
padded_height, padded_width = query_spatial_size(sample)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
padded_height, padded_width = query_spatial_size(flat_inputs)
if self.padding is not None:
pad_left, pad_right, pad_top, pad_bottom = self.padding
......@@ -505,8 +505,8 @@ class RandomPerspective(_RandomApplyTransform):
self.interpolation = interpolation
self.fill = _setup_fill_arg(fill)
def _get_params(self, sample: Any) -> Dict[str, Any]:
height, width = query_spatial_size(sample)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_spatial_size(flat_inputs)
distortion_scale = self.distortion_scale
......@@ -559,8 +559,8 @@ class ElasticTransform(Transform):
self.interpolation = interpolation
self.fill = _setup_fill_arg(fill)
def _get_params(self, sample: Any) -> Dict[str, Any]:
size = list(query_spatial_size(sample))
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
size = list(query_spatial_size(flat_inputs))
dx = torch.rand([1, 1] + size) * 2 - 1
if self.sigma[0] > 0.0:
......@@ -614,20 +614,20 @@ class RandomIoUCrop(Transform):
self.options = sampler_options
self.trials = trials
def _check_inputs(self, sample: Any) -> None:
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if not (
has_all(sample, features.BoundingBox)
and has_any(sample, PIL.Image.Image, features.Image, features.is_simple_tensor)
and has_any(sample, features.Label, features.OneHotLabel)
has_all(flat_inputs, features.BoundingBox)
and has_any(flat_inputs, PIL.Image.Image, features.Image, features.is_simple_tensor)
and has_any(flat_inputs, features.Label, features.OneHotLabel)
):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
"BoundingBoxes and Labels or OneHotLabels. Sample can also contain Masks."
)
def _get_params(self, sample: Any) -> Dict[str, Any]:
orig_h, orig_w = query_spatial_size(sample)
bboxes = query_bounding_box(sample)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_h, orig_w = query_spatial_size(flat_inputs)
bboxes = query_bounding_box(flat_inputs)
while True:
# sample an option
......@@ -712,8 +712,8 @@ class ScaleJitter(Transform):
self.interpolation = interpolation
self.antialias = antialias
def _get_params(self, sample: Any) -> Dict[str, Any]:
orig_height, orig_width = query_spatial_size(sample)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_height, orig_width = query_spatial_size(flat_inputs)
scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale
......@@ -740,8 +740,8 @@ class RandomShortestSize(Transform):
self.interpolation = interpolation
self.antialias = antialias
def _get_params(self, sample: Any) -> Dict[str, Any]:
orig_height, orig_width = query_spatial_size(sample)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
orig_height, orig_width = query_spatial_size(flat_inputs)
min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))
......@@ -771,20 +771,22 @@ class FixedSizeCrop(Transform):
self.padding_mode = padding_mode
def _check_inputs(self, sample: Any) -> None:
if not has_any(sample, PIL.Image.Image, features.Image, features.is_simple_tensor, features.Video):
def _check_inputs(self, flat_inputs: List[Any]) -> None:
if not has_any(flat_inputs, PIL.Image.Image, features.Image, features.is_simple_tensor, features.Video):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain an tensor or PIL image or a Video."
)
if has_any(sample, features.BoundingBox) and not has_any(sample, features.Label, features.OneHotLabel):
if has_any(flat_inputs, features.BoundingBox) and not has_any(
flat_inputs, features.Label, features.OneHotLabel
):
raise TypeError(
f"If a BoundingBox is contained in the input sample, "
f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel."
)
def _get_params(self, sample: Any) -> Dict[str, Any]:
height, width = query_spatial_size(sample)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
height, width = query_spatial_size(flat_inputs)
new_height = min(height, self.crop_height)
new_width = min(width, self.crop_width)
......@@ -798,7 +800,7 @@ class FixedSizeCrop(Transform):
left = int(offset_width * r)
try:
bounding_boxes = query_bounding_box(sample)
bounding_boxes = query_bounding_box(flat_inputs)
except ValueError:
bounding_boxes = None
......@@ -874,7 +876,7 @@ class RandomResize(Transform):
self.interpolation = interpolation
self.antialias = antialias
def _get_params(self, sample: Any) -> Dict[str, Any]:
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
size = int(torch.randint(self.min_size, self.max_size, ()))
return dict(size=[size])
......
import functools
from collections import defaultdict
from typing import Any, Callable, Dict, Sequence, Type, Union
from typing import Any, Callable, Dict, List, Sequence, Type, Union
import PIL.Image
......@@ -134,7 +134,7 @@ class GaussianBlur(Transform):
self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
def _get_params(self, sample: Any) -> Dict[str, Any]:
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item()
return dict(sigma=[sigma, sigma])
......@@ -167,8 +167,8 @@ class RemoveSmallBoundingBoxes(Transform):
super().__init__()
self.min_size = min_size
def _get_params(self, sample: Any) -> Dict[str, Any]:
bounding_box = query_bounding_box(sample)
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
bounding_box = query_bounding_box(flat_inputs)
# TODO: We can improve performance here by not using the `remove_small_boxes` function. It requires the box to
# be in XYXY format only to calculate the width and height internally. Thus, if the box is in XYWH or CXCYWH
......
import enum
from typing import Any, Callable, Dict, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Tuple, Type, Union
import PIL.Image
import torch
......@@ -23,27 +23,27 @@ class Transform(nn.Module):
super().__init__()
_log_api_usage_once(self)
def _check_inputs(self, sample: Any) -> None:
def _check_inputs(self, flat_inputs: List[Any]) -> None:
pass
def _get_params(self, sample: Any) -> Dict[str, Any]:
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict()
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
raise NotImplementedError
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
self._check_inputs(sample)
self._check_inputs(flat_inputs)
params = self._get_params(sample)
params = self._get_params(flat_inputs)
flat_inputs, spec = tree_flatten(sample)
flat_outputs = [
self._transform(inpt, params) if _isinstance(inpt, self._transformed_types) else inpt
for inpt in flat_inputs
]
return tree_unflatten(flat_outputs, spec)
def extra_repr(self) -> str:
......@@ -73,18 +73,19 @@ class _RandomApplyTransform(Transform):
# early afterwards in case the random check triggers. The same result could be achieved by calling
# `super().forward()` after the random check, but that would call `self._check_inputs` twice.
sample = inputs if len(inputs) > 1 else inputs[0]
inputs = inputs if len(inputs) > 1 else inputs[0]
flat_inputs, spec = tree_flatten(inputs)
self._check_inputs(sample)
self._check_inputs(flat_inputs)
if torch.rand(1) >= self.p:
return sample
return inputs
params = self._get_params(sample)
params = self._get_params(flat_inputs)
flat_inputs, spec = tree_flatten(sample)
flat_outputs = [
self._transform(inpt, params) if _isinstance(inpt, self._transformed_types) else inpt
for inpt in flat_inputs
]
return tree_unflatten(flat_outputs, spec)
import functools
import numbers
from collections import defaultdict
from typing import Any, Callable, Dict, Sequence, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, Union
import PIL.Image
from torch.utils._pytree import tree_flatten
from torchvision._utils import sequence_to_str
from torchvision.prototype import features
from torchvision.prototype.features._feature import FillType
......@@ -73,9 +72,8 @@ def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect",
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
def query_bounding_box(sample: Any) -> features.BoundingBox:
flat_sample, _ = tree_flatten(sample)
bounding_boxes = {item for item in flat_sample if isinstance(item, features.BoundingBox)}
def query_bounding_box(flat_inputs: List[Any]) -> features.BoundingBox:
bounding_boxes = {inpt for inpt in flat_inputs if isinstance(inpt, features.BoundingBox)}
if not bounding_boxes:
raise TypeError("No bounding box was found in the sample")
elif len(bounding_boxes) > 1:
......@@ -83,12 +81,11 @@ def query_bounding_box(sample: Any) -> features.BoundingBox:
return bounding_boxes.pop()
def query_chw(sample: Any) -> Tuple[int, int, int]:
flat_sample, _ = tree_flatten(sample)
def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
chws = {
tuple(get_dimensions(item))
for item in flat_sample
if isinstance(item, (features.Image, PIL.Image.Image, features.Video)) or features.is_simple_tensor(item)
tuple(get_dimensions(inpt))
for inpt in flat_inputs
if isinstance(inpt, (features.Image, PIL.Image.Image, features.Video)) or features.is_simple_tensor(inpt)
}
if not chws:
raise TypeError("No image or video was found in the sample")
......@@ -98,13 +95,12 @@ def query_chw(sample: Any) -> Tuple[int, int, int]:
return c, h, w
def query_spatial_size(sample: Any) -> Tuple[int, int]:
flat_sample, _ = tree_flatten(sample)
def query_spatial_size(flat_inputs: List[Any]) -> Tuple[int, int]:
sizes = {
tuple(get_spatial_size(item))
for item in flat_sample
if isinstance(item, (features.Image, PIL.Image.Image, features.Video, features.Mask, features.BoundingBox))
or features.is_simple_tensor(item)
tuple(get_spatial_size(inpt))
for inpt in flat_inputs
if isinstance(inpt, (features.Image, PIL.Image.Image, features.Video, features.Mask, features.BoundingBox))
or features.is_simple_tensor(inpt)
}
if not sizes:
raise TypeError("No image, video, mask or bounding box was found in the sample")
......@@ -121,19 +117,17 @@ def _isinstance(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], boo
return False
def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
flat_sample, _ = tree_flatten(sample)
for obj in flat_sample:
if _isinstance(obj, types_or_checks):
def has_any(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
for inpt in flat_inputs:
if _isinstance(inpt, types_or_checks):
return True
return False
def has_all(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
flat_sample, _ = tree_flatten(sample)
def has_all(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
for type_or_check in types_or_checks:
for obj in flat_sample:
if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
for inpt in flat_inputs:
if isinstance(inpt, type_or_check) if isinstance(type_or_check, type) else type_or_check(inpt):
break
else:
return False
......
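
With the helpers above now expecting an already-flattened list, an external caller flattens first and passes the list along. A rough usage sketch, assuming the helpers are importable from the prototype `_utils` module shown in this diff and that `features.Image` wraps a plain tensor (both assumptions about the surrounding package layout, not guaranteed by the diff itself):

```python
import torch
from torch.utils._pytree import tree_flatten
from torchvision.prototype import features
# Module path is an assumption based on this diff's file layout.
from torchvision.prototype.transforms._utils import has_any, query_chw, query_spatial_size

sample = {"image": features.Image(torch.rand(3, 24, 32)), "label": 7}

# Flatten once at the boundary ...
flat_inputs, _ = tree_flatten(sample)

# ... then every query helper works on the flat list without re-flattening.
print(query_chw(flat_inputs))                # expected: (3, 24, 32)
print(query_spatial_size(flat_inputs))       # expected: (24, 32)
print(has_any(flat_inputs, features.Image))  # expected: True
```
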