"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "fa35750d3ba8c2aadac4ce33515be302e12d00fc"
Unverified commit e3238e5a, authored by Philip Meier, committed by GitHub

only flatten a pytree once (#6767)

parent dc5fd831
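
Summary: before this commit, the prototype transforms re-flattened the input pytree several times per `forward()` call — `_check_inputs`, `_get_params`, and the `query_*`/`has_*` helpers each called `tree_flatten` on the raw sample. After this commit the pytree is flattened exactly once at the top of `forward()` and the flat list of leaves is threaded through everywhere else. A minimal sketch (toy data, not from the commit) of the flatten/unflatten round trip the transforms rely on:

```python
# Toy round trip showing the tree_flatten/tree_unflatten contract that
# forward() now uses exactly once per call.
from torch.utils._pytree import tree_flatten, tree_unflatten

sample = {"image": "img", "boxes": ["b1", "b2"]}

# Flatten the arbitrarily nested input into a flat list of leaves plus a
# spec that remembers the original structure.
flat_inputs, spec = tree_flatten(sample)  # flat_inputs == ["img", "b1", "b2"]

# Transform the flat leaves (this is what _transform sees) ...
flat_outputs = [leaf.upper() for leaf in flat_inputs]

# ... then rebuild the caller's original nesting.
assert tree_unflatten(flat_outputs, spec) == {"image": "IMG", "boxes": ["B1", "B2"]}
```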
@@ -437,7 +437,7 @@ class TestRandomZoomOut:
         image = mocker.MagicMock(spec=features.Image)
         h, w = image.spatial_size = (24, 32)

-        params = transform._get_params(image)
+        params = transform._get_params([image])

         assert len(params["padding"]) == 4
         assert 0 <= params["padding"][0] <= (side_range[1] - 1) * w
@@ -462,7 +462,7 @@ class TestRandomZoomOut:
         _ = transform(inpt)

         torch.manual_seed(12)
         torch.rand(1)  # random apply changes random state
-        params = transform._get_params(inpt)
+        params = transform._get_params([inpt])

         fill = transforms.functional._geometry._convert_fill_arg(fill)
         fn.assert_called_once_with(inpt, **params, fill=fill)
@@ -623,7 +623,7 @@ class TestRandomAffine:
         h, w = image.spatial_size

         transform = transforms.RandomAffine(degrees, translate=translate, scale=scale, shear=shear)
-        params = transform._get_params(image)
+        params = transform._get_params([image])

         if not isinstance(degrees, (list, tuple)):
             assert -degrees <= params["angle"] <= degrees
@@ -690,7 +690,7 @@ class TestRandomAffine:
         torch.manual_seed(12)
         _ = transform(inpt)

         torch.manual_seed(12)
-        params = transform._get_params(inpt)
+        params = transform._get_params([inpt])

         fill = transforms.functional._geometry._convert_fill_arg(fill)
         fn.assert_called_once_with(inpt, **params, interpolation=interpolation, fill=fill, center=center)
@@ -722,7 +722,7 @@ class TestRandomCrop:
         h, w = image.spatial_size

         transform = transforms.RandomCrop(size, padding=padding, pad_if_needed=pad_if_needed)
-        params = transform._get_params(image)
+        params = transform._get_params([image])

         if padding is not None:
             if isinstance(padding, int):
@@ -793,7 +793,7 @@ class TestRandomCrop:
         torch.manual_seed(12)
         _ = transform(inpt)

         torch.manual_seed(12)
-        params = transform._get_params(inpt)
+        params = transform._get_params([inpt])

         if padding is None and not pad_if_needed:
             fn_crop.assert_called_once_with(
                 inpt, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1]
@@ -832,7 +832,7 @@ class TestGaussianBlur:
     @pytest.mark.parametrize("sigma", [10.0, [10.0, 12.0]])
     def test__get_params(self, sigma):
         transform = transforms.GaussianBlur(3, sigma=sigma)
-        params = transform._get_params(None)
+        params = transform._get_params([])

         if isinstance(sigma, float):
             assert params["sigma"][0] == params["sigma"][1] == 10
@@ -867,7 +867,7 @@ class TestGaussianBlur:
         torch.manual_seed(12)
         _ = transform(inpt)

         torch.manual_seed(12)
-        params = transform._get_params(inpt)
+        params = transform._get_params([inpt])

         fn.assert_called_once_with(inpt, kernel_size, **params)
@@ -912,7 +912,7 @@ class TestRandomPerspective:
         image.num_channels = 3
         image.spatial_size = (24, 32)

-        params = transform._get_params(image)
+        params = transform._get_params([image])

         h, w = image.spatial_size
         assert "perspective_coeffs" in params
@@ -935,7 +935,7 @@ class TestRandomPerspective:
         _ = transform(inpt)

         torch.manual_seed(12)
         torch.rand(1)  # random apply changes random state
-        params = transform._get_params(inpt)
+        params = transform._get_params([inpt])

         fill = transforms.functional._geometry._convert_fill_arg(fill)
         fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
@@ -973,7 +973,7 @@ class TestElasticTransform:
         image.num_channels = 3
         image.spatial_size = (24, 32)

-        params = transform._get_params(image)
+        params = transform._get_params([image])

         h, w = image.spatial_size
         displacement = params["displacement"]
@@ -1006,7 +1006,7 @@ class TestElasticTransform:
         # Let's mock transform._get_params to control the output:
         transform._get_params = mocker.MagicMock()

         _ = transform(inpt)
-        params = transform._get_params(inpt)
+        params = transform._get_params([inpt])

         fill = transforms.functional._geometry._convert_fill_arg(fill)
         fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
@@ -1035,7 +1035,7 @@ class TestRandomErasing:
         transform = transforms.RandomErasing(value=[1, 2, 3, 4])

         with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value"):
-            transform._get_params(image)
+            transform._get_params([image])

     @pytest.mark.parametrize("value", [5.0, [1, 2, 3], "random"])
     def test__get_params(self, value, mocker):
@@ -1044,7 +1044,7 @@ class TestRandomErasing:
         image.spatial_size = (24, 32)

         transform = transforms.RandomErasing(value=value)
-        params = transform._get_params(image)
+        params = transform._get_params([image])

         v = params["v"]
         h, w = params["h"], params["w"]
@@ -1197,6 +1197,7 @@ class TestContainers:
         [
             [transforms.Pad(2), transforms.RandomCrop(28)],
             [lambda x: 2.0 * x, transforms.Pad(2), transforms.RandomCrop(28)],
+            [transforms.Pad(2), lambda x: 2.0 * x, transforms.RandomCrop(28)],
         ],
     )
     def test_ctor(self, transform_cls, trfms):
@@ -1339,7 +1340,7 @@ class TestScaleJitter:
         n_samples = 5
         for _ in range(n_samples):
-            params = transform._get_params(sample)
+            params = transform._get_params([sample])

             assert "size" in params
             size = params["size"]
@@ -1386,7 +1387,7 @@ class TestRandomShortestSize:
         transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size)

         sample = mocker.MagicMock(spec=features.Image, num_channels=3, spatial_size=spatial_size)
-        params = transform._get_params(sample)
+        params = transform._get_params([sample])

         assert "size" in params
         size = params["size"]
@@ -1554,13 +1555,13 @@ class TestFixedSizeCrop:
         transform = transforms.FixedSizeCrop(size=crop_size)

-        sample = dict(
-            image=make_image(size=spatial_size, color_space=features.ColorSpace.RGB),
-            bounding_boxes=make_bounding_box(
+        flat_inputs = [
+            make_image(size=spatial_size, color_space=features.ColorSpace.RGB),
+            make_bounding_box(
                 format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size, extra_dims=batch_shape
             ),
-        )
-        params = transform._get_params(sample)
+        ]
+        params = transform._get_params(flat_inputs)

         assert params["needs_crop"]
         assert params["height"] <= crop_size[0]
@@ -1759,7 +1760,7 @@ class TestRandomResize:
         transform = transforms.RandomResize(min_size=min_size, max_size=max_size)

         for _ in range(10):
-            params = transform._get_params(None)
+            params = transform._get_params([])

             assert isinstance(params["size"], list) and len(params["size"]) == 1
             size = params["size"][0]
...
@@ -639,7 +639,7 @@ class TestContainerTransforms:
         prototype_transform = prototype_transforms.RandomApply(
             [
                 prototype_transforms.Resize(256),
-                legacy_transforms.CenterCrop(224),
+                prototype_transforms.CenterCrop(224),
             ],
             p=p,
         )
...
@@ -45,8 +45,8 @@ class RandomErasing(_RandomApplyTransform):
         self._log_ratio = torch.log(torch.tensor(self.ratio))

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        img_c, img_h, img_w = query_chw(sample)
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        img_c, img_h, img_w = query_chw(flat_inputs)

         if isinstance(self.value, (int, float)):
             value = [self.value]
@@ -107,13 +107,13 @@ class _BaseMixupCutmix(_RandomApplyTransform):
         self.alpha = alpha
         self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))

-    def _check_inputs(self, sample: Any) -> None:
+    def _check_inputs(self, flat_inputs: List[Any]) -> None:
         if not (
-            has_any(sample, features.Image, features.Video, features.is_simple_tensor)
-            and has_any(sample, features.OneHotLabel)
+            has_any(flat_inputs, features.Image, features.Video, features.is_simple_tensor)
+            and has_any(flat_inputs, features.OneHotLabel)
         ):
             raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.")
-        if has_any(sample, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label):
+        if has_any(flat_inputs, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label):
             raise TypeError(
                 f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
             )
@@ -127,7 +127,7 @@ class _BaseMixupCutmix(_RandomApplyTransform):
 class RandomMixup(_BaseMixupCutmix):
-    def _get_params(self, sample: Any) -> Dict[str, Any]:
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(lam=float(self._dist.sample(())))

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
@@ -150,10 +150,10 @@ class RandomMixup(_BaseMixupCutmix):
 class RandomCutmix(_BaseMixupCutmix):
-    def _get_params(self, sample: Any) -> Dict[str, Any]:
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         lam = float(self._dist.sample(()))

-        H, W = query_spatial_size(sample)
+        H, W = query_spatial_size(flat_inputs)
         r_x = torch.randint(W, ())
         r_y = torch.randint(H, ())
@@ -344,9 +344,9 @@ class SimpleCopyPaste(_RandomApplyTransform):
             c3 += 1

     def forward(self, *inputs: Any) -> Any:
-        flat_sample, spec = tree_flatten(inputs)
-        images, targets = self._extract_image_targets(flat_sample)
+        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
+        images, targets = self._extract_image_targets(flat_inputs)

         # images = [t1, t2, ..., tN]
         # Let's define paste_images as shifted list of input images
@@ -384,6 +384,6 @@ class SimpleCopyPaste(_RandomApplyTransform):
             output_targets.append(output_target)

         # Insert updated images and targets into input flat_sample
-        self._insert_outputs(flat_sample, output_images, output_targets)
-        return tree_unflatten(flat_sample, spec)
+        self._insert_outputs(flat_inputs, output_images, output_targets)
+        return tree_unflatten(flat_inputs, spec)
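
Two things change in `SimpleCopyPaste.forward`: the leaves and spec now come from the unwrapped inputs (`inputs[0]` when a single positional argument was passed), so the unflattened output matches the shape the caller passed in, and the same `flat_inputs` list is reused for the final `tree_unflatten`. A sketch of the unwrap-then-flatten idiom (a standalone helper for illustration only; the real logic lives inline in `forward()`):

```python
from torch.utils._pytree import tree_flatten

def normalize_and_flatten(*inputs):
    # t(a, b) arrives as inputs == (a, b); t(sample) arrives as inputs == (sample,).
    # Unwrapping the singleton keeps the spec identical to the caller's structure,
    # so tree_unflatten later returns exactly the shape that was passed in.
    sample = inputs if len(inputs) > 1 else inputs[0]
    return tree_flatten(sample)
```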
@@ -4,7 +4,7 @@ from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, TypeV
 import PIL.Image
 import torch

-from torch.utils._pytree import tree_flatten, tree_unflatten
+from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
 from torchvision.prototype import features
 from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
 from torchvision.prototype.transforms.functional._meta import get_spatial_size
@@ -31,16 +31,17 @@ class _AutoAugmentBase(Transform):
         key = keys[int(torch.randint(len(keys), ()))]
         return key, dct[key]

-    def _extract_image_or_video(
+    def _flatten_and_extract_image_or_video(
         self,
-        sample: Any,
+        inputs: Any,
         unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.Mask),
-    ) -> Tuple[int, Union[features.ImageType, features.VideoType]]:
-        sample_flat, _ = tree_flatten(sample)
+    ) -> Tuple[Tuple[List[Any], TreeSpec, int], Union[features.ImageType, features.VideoType]]:
+        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
         image_or_videos = []
-        for id, inpt in enumerate(sample_flat):
+        for idx, inpt in enumerate(flat_inputs):
             if _isinstance(inpt, (features.Image, PIL.Image.Image, features.is_simple_tensor, features.Video)):
-                image_or_videos.append((id, inpt))
+                image_or_videos.append((idx, inpt))
             elif isinstance(inpt, unsupported_types):
                 raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
@@ -51,12 +52,18 @@ class _AutoAugmentBase(Transform):
                 f"Auto augment transformations are only properly defined for a single image or video, "
                 f"but found {len(image_or_videos)}."
             )

-        return image_or_videos[0]
+        idx, image_or_video = image_or_videos[0]
+        return (flat_inputs, spec, idx), image_or_video

-    def _put_into_sample(self, sample: Any, id: int, item: Any) -> Any:
-        sample_flat, spec = tree_flatten(sample)
-        sample_flat[id] = item
-        return tree_unflatten(sample_flat, spec)
+    def _unflatten_and_insert_image_or_video(
+        self,
+        flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int],
+        image_or_video: Union[features.ImageType, features.VideoType],
+    ) -> Any:
+        flat_inputs, spec, idx = flat_inputs_with_spec
+        flat_inputs[idx] = image_or_video
+        return tree_unflatten(flat_inputs, spec)

     def _apply_image_or_video_transform(
         self,
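
The old `_extract_image_or_video`/`_put_into_sample` pair flattened the sample twice: once to find the image and once more to put the augmented result back. The new pair flattens once and carries `(flat_inputs, spec, idx)` through to the reinsertion step. A toy round trip mirroring the helper pair (plain strings stand in for image tensors; not from the commit):

```python
# Toy version of the flatten-extract / insert-unflatten round trip above.
from torch.utils._pytree import tree_flatten, tree_unflatten

sample = {"image": "img", "label": 3}
flat_inputs, spec = tree_flatten(sample)  # flat_inputs == ["img", 3]
idx = flat_inputs.index("img")            # position of the image leaf

flat_inputs[idx] = flat_inputs[idx].upper()  # stand-in for the augmentation
assert tree_unflatten(flat_inputs, spec) == {"image": "IMG", "label": 3}
```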
@@ -275,9 +282,7 @@ class AutoAugment(_AutoAugmentBase):
             raise ValueError(f"The provided policy {policy} is not recognized.")

     def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
-
-        id, image_or_video = self._extract_image_or_video(sample)
+        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
         height, width = get_spatial_size(image_or_video)

         policy = self._policies[int(torch.randint(len(self._policies), ()))]
@@ -300,7 +305,7 @@ class AutoAugment(_AutoAugmentBase):
                 image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
             )

-        return self._put_into_sample(sample, id, image_or_video)
+        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)


 class RandAugment(_AutoAugmentBase):
@@ -346,9 +351,7 @@ class RandAugment(_AutoAugmentBase):
         self.num_magnitude_bins = num_magnitude_bins

     def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
-
-        id, image_or_video = self._extract_image_or_video(sample)
+        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
         height, width = get_spatial_size(image_or_video)

         for _ in range(self.num_ops):
@@ -364,7 +367,7 @@ class RandAugment(_AutoAugmentBase):
                 image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
             )

-        return self._put_into_sample(sample, id, image_or_video)
+        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)


 class TrivialAugmentWide(_AutoAugmentBase):
@@ -400,9 +403,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
         self.num_magnitude_bins = num_magnitude_bins

     def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
-
-        id, image_or_video = self._extract_image_or_video(sample)
+        flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
         height, width = get_spatial_size(image_or_video)

         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
@@ -418,7 +419,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
             image_or_video = self._apply_image_or_video_transform(
                 image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
             )

-        return self._put_into_sample(sample, id, image_or_video)
+        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)


 class AugMix(_AutoAugmentBase):
@@ -471,8 +472,7 @@ class AugMix(_AutoAugmentBase):
         return torch._sample_dirichlet(params)

     def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
-        id, orig_image_or_video = self._extract_image_or_video(sample)
+        flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)
         height, width = get_spatial_size(orig_image_or_video)

         if isinstance(orig_image_or_video, torch.Tensor):
@@ -525,4 +525,4 @@ class AugMix(_AutoAugmentBase):
         elif isinstance(orig_image_or_video, PIL.Image.Image):
             mix = F.to_image_pil(mix)

-        return self._put_into_sample(sample, id, mix)
+        return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix)
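
With the helper pair in place, all four auto-augment `forward()` methods collapse to the same three-step skeleton (paraphrased from the diff; `augment` is a placeholder for each policy's specific loop, not a real method):

```python
# Paraphrased skeleton shared by AutoAugment, RandAugment, TrivialAugmentWide,
# and AugMix after this commit.
def forward(self, *inputs):
    flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
    image_or_video = augment(image_or_video)  # policy-specific augmentation loop
    return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
```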
 import collections.abc
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

 import PIL.Image
 import torch
@@ -53,7 +53,7 @@ class ColorJitter(Transform):
     def _generate_value(left: float, right: float) -> float:
         return float(torch.distributions.Uniform(left, right).sample())

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         fn_idx = torch.randperm(4)

         b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1])
@@ -99,8 +99,8 @@ class RandomPhotometricDistort(Transform):
         self.saturation = saturation
         self.p = p

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        num_channels, *_ = query_chw(sample)
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        num_channels, *_ = query_chw(flat_inputs)
         return dict(
             zip(
                 ["brightness", "contrast1", "saturation", "hue", "contrast2"],
...
 import warnings
-from typing import Any, Dict, Union
+from typing import Any, Dict, List, Union

 import numpy as np
 import PIL.Image
@@ -79,8 +79,8 @@ class RandomGrayscale(_RandomApplyTransform):
         super().__init__(p=p)

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        num_input_channels, *_ = query_chw(sample)
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        num_input_channels, *_ = query_chw(flat_inputs)
         return dict(num_input_channels=num_input_channels)

     def _transform(
...
@@ -104,8 +104,8 @@ class RandomResizedCrop(Transform):
         self._log_ratio = torch.log(torch.tensor(self.ratio))

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        height, width = query_spatial_size(sample)
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        height, width = query_spatial_size(flat_inputs)
         area = height * width

         log_ratio = self._log_ratio
@@ -184,8 +184,8 @@ class FiveCrop(Transform):
     ) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
         return F.five_crop(inpt, self.size)

-    def _check_inputs(self, sample: Any) -> None:
-        if has_any(sample, features.BoundingBox, features.Mask):
+    def _check_inputs(self, flat_inputs: List[Any]) -> None:
+        if has_any(flat_inputs, features.BoundingBox, features.Mask):
             raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
@@ -201,8 +201,8 @@ class TenCrop(Transform):
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
         self.vertical_flip = vertical_flip

-    def _check_inputs(self, sample: Any) -> None:
-        if has_any(sample, features.BoundingBox, features.Mask):
+    def _check_inputs(self, flat_inputs: List[Any]) -> None:
+        if has_any(flat_inputs, features.BoundingBox, features.Mask):
             raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")

     def _transform(
@@ -256,8 +256,8 @@ class RandomZoomOut(_RandomApplyTransform):
         if side_range[0] < 1.0 or side_range[0] > side_range[1]:
             raise ValueError(f"Invalid canvas side range provided {side_range}.")

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        orig_h, orig_w = query_spatial_size(sample)
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        orig_h, orig_w = query_spatial_size(flat_inputs)

         r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
         canvas_width = int(orig_w * r)
@@ -299,7 +299,7 @@ class RandomRotation(Transform):
         self.center = center

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
         return dict(angle=angle)
@@ -355,8 +355,8 @@ class RandomAffine(Transform):
         self.center = center

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        height, width = query_spatial_size(sample)
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        height, width = query_spatial_size(flat_inputs)

         angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item())
         if self.translate is not None:
@@ -417,8 +417,8 @@ class RandomCrop(Transform):
         self.fill = _setup_fill_arg(fill)
         self.padding_mode = padding_mode

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        padded_height, padded_width = query_spatial_size(sample)
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        padded_height, padded_width = query_spatial_size(flat_inputs)

         if self.padding is not None:
             pad_left, pad_right, pad_top, pad_bottom = self.padding
@@ -505,8 +505,8 @@ class RandomPerspective(_RandomApplyTransform):
         self.interpolation = interpolation
         self.fill = _setup_fill_arg(fill)

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        height, width = query_spatial_size(sample)
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        height, width = query_spatial_size(flat_inputs)

         distortion_scale = self.distortion_scale
@@ -559,8 +559,8 @@ class ElasticTransform(Transform):
         self.interpolation = interpolation
         self.fill = _setup_fill_arg(fill)

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        size = list(query_spatial_size(sample))
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        size = list(query_spatial_size(flat_inputs))

         dx = torch.rand([1, 1] + size) * 2 - 1
         if self.sigma[0] > 0.0:
@@ -614,20 +614,20 @@ class RandomIoUCrop(Transform):
         self.options = sampler_options
         self.trials = trials

-    def _check_inputs(self, sample: Any) -> None:
+    def _check_inputs(self, flat_inputs: List[Any]) -> None:
         if not (
-            has_all(sample, features.BoundingBox)
-            and has_any(sample, PIL.Image.Image, features.Image, features.is_simple_tensor)
-            and has_any(sample, features.Label, features.OneHotLabel)
+            has_all(flat_inputs, features.BoundingBox)
+            and has_any(flat_inputs, PIL.Image.Image, features.Image, features.is_simple_tensor)
+            and has_any(flat_inputs, features.Label, features.OneHotLabel)
         ):
             raise TypeError(
                 f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
                 "BoundingBoxes and Labels or OneHotLabels. Sample can also contain Masks."
             )

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        orig_h, orig_w = query_spatial_size(sample)
-        bboxes = query_bounding_box(sample)
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        orig_h, orig_w = query_spatial_size(flat_inputs)
+        bboxes = query_bounding_box(flat_inputs)

         while True:
             # sample an option
@@ -712,8 +712,8 @@ class ScaleJitter(Transform):
         self.interpolation = interpolation
         self.antialias = antialias

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        orig_height, orig_width = query_spatial_size(sample)
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        orig_height, orig_width = query_spatial_size(flat_inputs)

         scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
         r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale
@@ -740,8 +740,8 @@ class RandomShortestSize(Transform):
         self.interpolation = interpolation
         self.antialias = antialias

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        orig_height, orig_width = query_spatial_size(sample)
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        orig_height, orig_width = query_spatial_size(flat_inputs)

         min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
         r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))
@@ -771,20 +771,22 @@ class FixedSizeCrop(Transform):
         self.padding_mode = padding_mode

-    def _check_inputs(self, sample: Any) -> None:
-        if not has_any(sample, PIL.Image.Image, features.Image, features.is_simple_tensor, features.Video):
+    def _check_inputs(self, flat_inputs: List[Any]) -> None:
+        if not has_any(flat_inputs, PIL.Image.Image, features.Image, features.is_simple_tensor, features.Video):
             raise TypeError(
                 f"{type(self).__name__}() requires input sample to contain an tensor or PIL image or a Video."
             )

-        if has_any(sample, features.BoundingBox) and not has_any(sample, features.Label, features.OneHotLabel):
+        if has_any(flat_inputs, features.BoundingBox) and not has_any(
+            flat_inputs, features.Label, features.OneHotLabel
+        ):
             raise TypeError(
                 f"If a BoundingBox is contained in the input sample, "
                 f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel."
             )

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        height, width = query_spatial_size(sample)
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        height, width = query_spatial_size(flat_inputs)

         new_height = min(height, self.crop_height)
         new_width = min(width, self.crop_width)
@@ -798,7 +800,7 @@ class FixedSizeCrop(Transform):
         left = int(offset_width * r)

         try:
-            bounding_boxes = query_bounding_box(sample)
+            bounding_boxes = query_bounding_box(flat_inputs)
         except ValueError:
             bounding_boxes = None
@@ -874,7 +876,7 @@ class RandomResize(Transform):
         self.interpolation = interpolation
         self.antialias = antialias

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         size = int(torch.randint(self.min_size, self.max_size, ()))
         return dict(size=[size])
...
 import functools
 from collections import defaultdict
-from typing import Any, Callable, Dict, Sequence, Type, Union
+from typing import Any, Callable, Dict, List, Sequence, Type, Union

 import PIL.Image
@@ -134,7 +134,7 @@ class GaussianBlur(Transform):
         self.sigma = _setup_float_or_seq(sigma, "sigma", 2)

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item()
         return dict(sigma=[sigma, sigma])
@@ -167,8 +167,8 @@ class RemoveSmallBoundingBoxes(Transform):
         super().__init__()
         self.min_size = min_size

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
-        bounding_box = query_bounding_box(sample)
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        bounding_box = query_bounding_box(flat_inputs)

         # TODO: We can improve performance here by not using the `remove_small_boxes` function. It requires the box to
         # be in XYXY format only to calculate the width and height internally. Thus, if the box is in XYWH or CXCYWH
...
 import enum
-from typing import Any, Callable, Dict, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Tuple, Type, Union

 import PIL.Image
 import torch
@@ -23,27 +23,27 @@ class Transform(nn.Module):
         super().__init__()
         _log_api_usage_once(self)

-    def _check_inputs(self, sample: Any) -> None:
+    def _check_inputs(self, flat_inputs: List[Any]) -> None:
         pass

-    def _get_params(self, sample: Any) -> Dict[str, Any]:
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict()

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         raise NotImplementedError

     def forward(self, *inputs: Any) -> Any:
-        sample = inputs if len(inputs) > 1 else inputs[0]
+        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])

-        self._check_inputs(sample)
+        self._check_inputs(flat_inputs)

-        params = self._get_params(sample)
+        params = self._get_params(flat_inputs)

-        flat_inputs, spec = tree_flatten(sample)
         flat_outputs = [
             self._transform(inpt, params) if _isinstance(inpt, self._transformed_types) else inpt
             for inpt in flat_inputs
         ]

         return tree_unflatten(flat_outputs, spec)

     def extra_repr(self) -> str:
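
`Transform.forward` is the heart of the change: it flattens once, up front, and every hook downstream (`_check_inputs`, `_get_params`, the per-leaf `_transform` loop) works on the same `flat_inputs` list. A hedged sketch of a subclass under the new contract (`RandomHalfCrop` is illustrative, not part of torchvision; `Transform`, `query_spatial_size`, and `F` are the names from the modules in this diff):

```python
# Illustrative subclass showing the new _get_params contract: it receives the
# already-flattened leaves, so helpers like query_spatial_size no longer need
# to flatten again internally.
class RandomHalfCrop(Transform):
    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        height, width = query_spatial_size(flat_inputs)  # no tree_flatten here
        return dict(top=height // 4, left=width // 4, height=height // 2, width=width // 2)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.crop(inpt, **params)
```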
@@ -73,18 +73,19 @@ class _RandomApplyTransform(Transform):
         # early afterwards in case the random check triggers. The same result could be achieved by calling
         # `super().forward()` after the random check, but that would call `self._check_inputs` twice.

-        sample = inputs if len(inputs) > 1 else inputs[0]
+        inputs = inputs if len(inputs) > 1 else inputs[0]
+        flat_inputs, spec = tree_flatten(inputs)

-        self._check_inputs(sample)
+        self._check_inputs(flat_inputs)

         if torch.rand(1) >= self.p:
-            return sample
+            return inputs

-        params = self._get_params(sample)
+        params = self._get_params(flat_inputs)

-        flat_inputs, spec = tree_flatten(sample)
         flat_outputs = [
             self._transform(inpt, params) if _isinstance(inpt, self._transformed_types) else inpt
             for inpt in flat_inputs
         ]

         return tree_unflatten(flat_outputs, spec)
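
`_RandomApplyTransform.forward` keeps one subtlety: the no-op branch returns the unwrapped `inputs` object itself, so when the random check triggers, the caller gets back the exact container it passed in and no `tree_unflatten` cost is paid. Illustrative behavior (toy usage; `make_image` is a test helper assumed here):

```python
# p=0.0 makes torch.rand(1) >= self.p always true, forcing the no-op branch.
t = transforms.RandomZoomOut(p=0.0)
sample = {"image": make_image(size=(24, 32))}
assert t(sample) is sample  # the original container comes back untouched
```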
 import functools
 import numbers
 from collections import defaultdict
-from typing import Any, Callable, Dict, Sequence, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, Union

 import PIL.Image

-from torch.utils._pytree import tree_flatten
 from torchvision._utils import sequence_to_str
 from torchvision.prototype import features
 from torchvision.prototype.features._feature import FillType
@@ -73,9 +72,8 @@ def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect",
         raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")


-def query_bounding_box(sample: Any) -> features.BoundingBox:
-    flat_sample, _ = tree_flatten(sample)
-    bounding_boxes = {item for item in flat_sample if isinstance(item, features.BoundingBox)}
+def query_bounding_box(flat_inputs: List[Any]) -> features.BoundingBox:
+    bounding_boxes = {inpt for inpt in flat_inputs if isinstance(inpt, features.BoundingBox)}
     if not bounding_boxes:
         raise TypeError("No bounding box was found in the sample")
     elif len(bounding_boxes) > 1:
@@ -83,12 +81,11 @@ def query_bounding_box(sample: Any) -> features.BoundingBox:
     return bounding_boxes.pop()


-def query_chw(sample: Any) -> Tuple[int, int, int]:
-    flat_sample, _ = tree_flatten(sample)
+def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
     chws = {
-        tuple(get_dimensions(item))
-        for item in flat_sample
-        if isinstance(item, (features.Image, PIL.Image.Image, features.Video)) or features.is_simple_tensor(item)
+        tuple(get_dimensions(inpt))
+        for inpt in flat_inputs
+        if isinstance(inpt, (features.Image, PIL.Image.Image, features.Video)) or features.is_simple_tensor(inpt)
     }
     if not chws:
         raise TypeError("No image or video was found in the sample")
@@ -98,13 +95,12 @@ def query_chw(sample: Any) -> Tuple[int, int, int]:
     return c, h, w


-def query_spatial_size(sample: Any) -> Tuple[int, int]:
-    flat_sample, _ = tree_flatten(sample)
+def query_spatial_size(flat_inputs: List[Any]) -> Tuple[int, int]:
     sizes = {
-        tuple(get_spatial_size(item))
-        for item in flat_sample
-        if isinstance(item, (features.Image, PIL.Image.Image, features.Video, features.Mask, features.BoundingBox))
-        or features.is_simple_tensor(item)
+        tuple(get_spatial_size(inpt))
+        for inpt in flat_inputs
+        if isinstance(inpt, (features.Image, PIL.Image.Image, features.Video, features.Mask, features.BoundingBox))
+        or features.is_simple_tensor(inpt)
     }
     if not sizes:
         raise TypeError("No image, video, mask or bounding box was found in the sample")
@@ -121,19 +117,17 @@ def _isinstance(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], boo
     return False


-def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
-    flat_sample, _ = tree_flatten(sample)
-    for obj in flat_sample:
-        if _isinstance(obj, types_or_checks):
+def has_any(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
+    for inpt in flat_inputs:
+        if _isinstance(inpt, types_or_checks):
             return True
     return False


-def has_all(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
-    flat_sample, _ = tree_flatten(sample)
+def has_all(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
     for type_or_check in types_or_checks:
-        for obj in flat_sample:
-            if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
+        for inpt in flat_inputs:
+            if isinstance(inpt, type_or_check) if isinstance(type_or_check, type) else type_or_check(inpt):
                 break
         else:
             return False
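
The query/has helpers above drop their internal `tree_flatten` entirely; callers are now responsible for flattening first and can reuse the same flat list for every helper. A hedged usage sketch (the `image`/`boxes`/`label` leaf variables are assumptions, not from the diff):

```python
from torch.utils._pytree import tree_flatten

# Flatten once, then reuse the same flat list for every helper call.
flat_inputs, _ = tree_flatten({"image": image, "boxes": boxes, "label": label})

has_any(flat_inputs, features.Image)                        # any image leaf present?
has_all(flat_inputs, features.Image, features.BoundingBox)  # both kinds present?
c, h, w = query_chw(flat_inputs)  # raises TypeError if no image/video leaf exists
```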