Unverified Commit 408c9bea authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

make prototype datasets traversable (#4950)


Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
parent 59baae99
...@@ -2,6 +2,7 @@ import io ...@@ -2,6 +2,7 @@ import io
import builtin_dataset_mocks import builtin_dataset_mocks
import pytest import pytest
from torch.utils.data.graph import traverse
from torchdata.datapipes.iter import IterDataPipe from torchdata.datapipes.iter import IterDataPipe
from torchvision.prototype import datasets, features from torchvision.prototype import datasets, features
from torchvision.prototype.datasets._api import DEFAULT_DECODER from torchvision.prototype.datasets._api import DEFAULT_DECODER
...@@ -83,6 +84,10 @@ class TestCommon: ...@@ -83,6 +84,10 @@ class TestCommon:
if not any(isinstance(value, features.Feature) for value in sample.values()): if not any(isinstance(value, features.Feature) for value in sample.values()):
raise AssertionError("The sample contained no feature.") raise AssertionError("The sample contained no feature.")
@dataset_parametrization()
def test_traversable(self, dataset, mock_info):
traverse(dataset)
class TestQMNIST: class TestQMNIST:
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
import enum import enum
import functools
import gzip import gzip
import io import io
import lzma import lzma
...@@ -101,35 +102,37 @@ class Enumerator(IterDataPipe[Tuple[int, D]]): ...@@ -101,35 +102,37 @@ class Enumerator(IterDataPipe[Tuple[int, D]]):
yield from enumerate(self.datapipe, self.start) yield from enumerate(self.datapipe, self.start)
def _getitem_closure(obj: Any, *, items: Tuple[Any, ...]) -> Any:
for item in items:
obj = obj[item]
return obj
def getitem(*items: Any) -> Callable[[Any], Any]: def getitem(*items: Any) -> Callable[[Any], Any]:
def wrapper(obj: Any) -> Any: return functools.partial(_getitem_closure, items=items)
for item in items:
obj = obj[item]
return obj def _path_attribute_accessor(path: pathlib.Path, *, name: str) -> D:
return cast(D, getattr(path, name))
return wrapper
def _path_accessor_closure(data: Tuple[str, Any], *, getter: Callable[[pathlib.Path], D]) -> D:
return getter(pathlib.Path(data[0]))
def path_accessor(getter: Union[str, Callable[[pathlib.Path], D]]) -> Callable[[Tuple[str, Any]], D]: def path_accessor(getter: Union[str, Callable[[pathlib.Path], D]]) -> Callable[[Tuple[str, Any]], D]:
if isinstance(getter, str): if isinstance(getter, str):
name = getter getter = functools.partial(_path_attribute_accessor, name=getter)
def getter(path: pathlib.Path) -> D: return functools.partial(_path_accessor_closure, getter=getter)
return cast(D, getattr(path, name))
def wrapper(data: Tuple[str, Any]) -> D:
return getter(pathlib.Path(data[0])) # type: ignore[operator]
return wrapper def _path_comparator_closure(data: Tuple[str, Any], *, accessor: Callable[[Tuple[str, Any]], D], value: D) -> bool:
return accessor(data) == value
def path_comparator(getter: Union[str, Callable[[pathlib.Path], D]], value: D) -> Callable[[Tuple[str, Any]], bool]: def path_comparator(getter: Union[str, Callable[[pathlib.Path], D]], value: D) -> Callable[[Tuple[str, Any]], bool]:
accessor = path_accessor(getter) return functools.partial(_path_comparator_closure, accessor=path_accessor(getter), value=value)
def wrapper(data: Tuple[str, Any]) -> bool:
return accessor(data) == value
return wrapper
class CompressionType(enum.Enum): class CompressionType(enum.Enum):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment