"vscode:/vscode.git/clone" did not exist on "e4d264e4eb23de860220c87a934f4dfb41879923"
Unverified Commit 2a0eea82 authored by Philip Meier, committed by GitHub
Browse files

replace query_recursively with pytree implementation (#6434)



* replace query_recursively with pytree implementation

* simplify
Co-authored-by: vfdev <vfdev.5@gmail.com>
parent 961d97b2
......@@ -3,28 +3,23 @@ from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, T
import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.prototype.utils._internal import query_recursively
from torchvision.transforms.autoaugment import AutoAugmentPolicy
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
from ._utils import get_image_dimensions
from ._utils import get_image_dimensions, is_simple_tensor
# Generic key/value type variables.
# NOTE(review): neither K nor V is referenced in the visible part of this
# file — possibly leftovers from the removed recursive traversal; confirm
# against the full module before deleting.
K = TypeVar("K")
V = TypeVar("V")
def _put_into_sample(sample: Any, id: Tuple[Any, ...], item: Any) -> Any:
    # Pre-pytree implementation (removed by this diff): ``id`` is a path of
    # keys/indices into the nested ``sample`` container. Walks the path and
    # mutates ``sample`` in place, returning it for convenience.
    if not id:
        # Empty path: the whole sample is replaced by the item.
        return item
    parent = sample
    # Descend to the container that directly holds the target slot.
    for key in id[:-1]:
        parent = parent[key]
    # In-place replacement of the leaf at the final path component.
    parent[id[-1]] = item
    return sample
def _put_into_sample(sample: Any, id: int, item: Any) -> Any:
sample_flat, spec = tree_flatten(sample)
sample_flat[id] = item
return tree_unflatten(sample_flat, spec)
class _AutoAugmentBase(Transform):
......@@ -47,18 +42,15 @@ class _AutoAugmentBase(Transform):
self,
sample: Any,
unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask),
) -> Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]:
def fn(
id: Tuple[Any, ...], inpt: Any
) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]:
if type(inpt) in {torch.Tensor, features.Image} or isinstance(inpt, PIL.Image.Image):
return id, inpt
) -> Tuple[int, Union[PIL.Image.Image, torch.Tensor, features.Image]]:
sample_flat, _ = tree_flatten(sample)
images = []
for id, inpt in enumerate(sample_flat):
if isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt):
images.append((id, inpt))
elif isinstance(inpt, unsupported_types):
raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
else:
return None
images = list(query_recursively(fn, sample))
if not images:
raise TypeError("Found no image in the sample.")
if len(images) > 1:
......
......@@ -3,7 +3,7 @@ import difflib
import io
import mmap
import platform
from typing import Any, BinaryIO, Callable, Collection, Iterator, Optional, Sequence, Tuple, TypeVar, Union
from typing import BinaryIO, Callable, Collection, Sequence, TypeVar, Union
import numpy as np
import torch
......@@ -14,7 +14,6 @@ __all__ = [
"add_suggestion",
"fromfile",
"ReadOnlyTensorBuffer",
"query_recursively",
]
......@@ -125,20 +124,3 @@ class ReadOnlyTensorBuffer:
cursor = self.tell()
offset, whence = (0, io.SEEK_END) if size == -1 else (size, io.SEEK_CUR)
return self._memory[slice(cursor, self.seek(offset, whence))].tobytes()
def query_recursively(
    fn: Callable[[Tuple[Any, ...], Any], Optional[D]], obj: Any, *, id: Tuple[Any, ...] = ()
) -> Iterator[D]:
    """Depth-first traversal of nested sequences and mappings.

    ``fn`` is invoked on every leaf with the path (``id``) leading to it and
    the leaf itself; whenever ``fn`` returns a value other than ``None``, that
    value is yielded.
    """
    # str is deliberately treated as a leaf: a one-character string indexes to
    # itself ("a" == "a"[0][0]...), so recursing into it would never terminate.
    if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
        children = enumerate(obj)
    elif isinstance(obj, collections.abc.Mapping):
        children = obj.items()
    else:
        # Leaf node: query it and yield the result if the query matched.
        result = fn(id, obj)
        if result is not None:
            yield result
        return
    for key, child in children:
        yield from query_recursively(fn, child, id=(*id, key))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment