Unverified Commit 40a0ab79 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

remove SampleQuery from prototype datasets (#5991)

parent e6edcef4
from . import _internal # usort: skip from . import _internal # usort: skip
from ._dataset import Dataset from ._dataset import Dataset
from ._query import SampleQuery
from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource
import collections.abc
from typing import Any, Callable, Iterator, Optional, Tuple, TypeVar, cast
from torchvision.prototype.features import BoundingBox, Image
T = TypeVar("T")
class SampleQuery:
def __init__(self, sample: Any) -> None:
self.sample = sample
@staticmethod
def _query_recursively(sample: Any, fn: Callable[[Any], Optional[T]]) -> Iterator[T]:
if isinstance(sample, (collections.abc.Sequence, collections.abc.Mapping)):
for item in sample.values() if isinstance(sample, collections.abc.Mapping) else sample:
yield from SampleQuery._query_recursively(item, fn)
else:
result = fn(sample)
if result is not None:
yield result
def query(self, fn: Callable[[Any], Optional[T]]) -> T:
results = set(self._query_recursively(self.sample, fn))
if not results:
raise RuntimeError("Query turned up empty.")
elif len(results) > 1:
raise RuntimeError(f"Found more than one result: {results}")
return results.pop()
def image_size(self) -> Tuple[int, int]:
def fn(sample: Any) -> Optional[Tuple[int, int]]:
if isinstance(sample, Image):
return cast(Tuple[int, int], sample.shape[-2:])
elif isinstance(sample, BoundingBox):
return sample.image_size
else:
return None
return self.query(fn)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment