"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "f9348a04a287b20f9cb0eb522c3c1dfd0b4391cd"
Unverified Commit 2a0eea82 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

replace query_recursively with pytree implementation (#6434)



* replace query_recursively with pytree implementation

* simplify
Co-authored-by: default avatarvfdev <vfdev.5@gmail.com>
parent 961d97b2
...@@ -3,28 +3,23 @@ from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, T ...@@ -3,28 +3,23 @@ from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, T
import PIL.Image import PIL.Image
import torch import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform from torchvision.prototype.transforms import functional as F, Transform
from torchvision.prototype.utils._internal import query_recursively
from torchvision.transforms.autoaugment import AutoAugmentPolicy from torchvision.transforms.autoaugment import AutoAugmentPolicy
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
from ._utils import get_image_dimensions from ._utils import get_image_dimensions, is_simple_tensor
K = TypeVar("K") K = TypeVar("K")
V = TypeVar("V") V = TypeVar("V")
def _put_into_sample(sample: Any, id: Tuple[Any, ...], item: Any) -> Any: def _put_into_sample(sample: Any, id: int, item: Any) -> Any:
if not id: sample_flat, spec = tree_flatten(sample)
return item sample_flat[id] = item
return tree_unflatten(sample_flat, spec)
parent = sample
for key in id[:-1]:
parent = parent[key]
parent[id[-1]] = item
return sample
class _AutoAugmentBase(Transform): class _AutoAugmentBase(Transform):
...@@ -47,18 +42,15 @@ class _AutoAugmentBase(Transform): ...@@ -47,18 +42,15 @@ class _AutoAugmentBase(Transform):
self, self,
sample: Any, sample: Any,
unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask), unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask),
) -> Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]: ) -> Tuple[int, Union[PIL.Image.Image, torch.Tensor, features.Image]]:
def fn( sample_flat, _ = tree_flatten(sample)
id: Tuple[Any, ...], inpt: Any images = []
) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]: for id, inpt in enumerate(sample_flat):
if type(inpt) in {torch.Tensor, features.Image} or isinstance(inpt, PIL.Image.Image): if isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt):
return id, inpt images.append((id, inpt))
elif isinstance(inpt, unsupported_types): elif isinstance(inpt, unsupported_types):
raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()") raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
else:
return None
images = list(query_recursively(fn, sample))
if not images: if not images:
raise TypeError("Found no image in the sample.") raise TypeError("Found no image in the sample.")
if len(images) > 1: if len(images) > 1:
......
...@@ -3,7 +3,7 @@ import difflib ...@@ -3,7 +3,7 @@ import difflib
import io import io
import mmap import mmap
import platform import platform
from typing import Any, BinaryIO, Callable, Collection, Iterator, Optional, Sequence, Tuple, TypeVar, Union from typing import BinaryIO, Callable, Collection, Sequence, TypeVar, Union
import numpy as np import numpy as np
import torch import torch
...@@ -14,7 +14,6 @@ __all__ = [ ...@@ -14,7 +14,6 @@ __all__ = [
"add_suggestion", "add_suggestion",
"fromfile", "fromfile",
"ReadOnlyTensorBuffer", "ReadOnlyTensorBuffer",
"query_recursively",
] ]
...@@ -125,20 +124,3 @@ class ReadOnlyTensorBuffer: ...@@ -125,20 +124,3 @@ class ReadOnlyTensorBuffer:
cursor = self.tell() cursor = self.tell()
offset, whence = (0, io.SEEK_END) if size == -1 else (size, io.SEEK_CUR) offset, whence = (0, io.SEEK_END) if size == -1 else (size, io.SEEK_CUR)
return self._memory[slice(cursor, self.seek(offset, whence))].tobytes() return self._memory[slice(cursor, self.seek(offset, whence))].tobytes()
def query_recursively(
fn: Callable[[Tuple[Any, ...], Any], Optional[D]], obj: Any, *, id: Tuple[Any, ...] = ()
) -> Iterator[D]:
# We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
# "a" == "a"[0][0]...
if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
for idx, item in enumerate(obj):
yield from query_recursively(fn, item, id=(*id, idx))
elif isinstance(obj, collections.abc.Mapping):
for key, item in obj.items():
yield from query_recursively(fn, item, id=(*id, key))
else:
result = fn(id, obj)
if result is not None:
yield result
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment