Unverified Commit 408c9bea authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

make prototype datasets traversable (#4950)


Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
parent 59baae99
......@@ -2,6 +2,7 @@ import io
import builtin_dataset_mocks
import pytest
from torch.utils.data.graph import traverse
from torchdata.datapipes.iter import IterDataPipe
from torchvision.prototype import datasets, features
from torchvision.prototype.datasets._api import DEFAULT_DECODER
......@@ -83,6 +84,10 @@ class TestCommon:
if not any(isinstance(value, features.Feature) for value in sample.values()):
raise AssertionError("The sample contained no feature.")
@dataset_parametrization()
def test_traversable(self, dataset, mock_info):
traverse(dataset)
class TestQMNIST:
@pytest.mark.parametrize(
......
import enum
import functools
import gzip
import io
import lzma
......@@ -101,35 +102,37 @@ class Enumerator(IterDataPipe[Tuple[int, D]]):
yield from enumerate(self.datapipe, self.start)
def _getitem_closure(obj: Any, *, items: Tuple[Any, ...]) -> Any:
for item in items:
obj = obj[item]
return obj
def getitem(*items: Any) -> Callable[[Any], Any]:
def wrapper(obj: Any) -> Any:
for item in items:
obj = obj[item]
return obj
return functools.partial(_getitem_closure, items=items)
def _path_attribute_accessor(path: pathlib.Path, *, name: str) -> D:
return cast(D, getattr(path, name))
return wrapper
def _path_accessor_closure(data: Tuple[str, Any], *, getter: Callable[[pathlib.Path], D]) -> D:
return getter(pathlib.Path(data[0]))
def path_accessor(getter: Union[str, Callable[[pathlib.Path], D]]) -> Callable[[Tuple[str, Any]], D]:
if isinstance(getter, str):
name = getter
getter = functools.partial(_path_attribute_accessor, name=getter)
def getter(path: pathlib.Path) -> D:
return cast(D, getattr(path, name))
return functools.partial(_path_accessor_closure, getter=getter)
def wrapper(data: Tuple[str, Any]) -> D:
return getter(pathlib.Path(data[0])) # type: ignore[operator]
return wrapper
def _path_comparator_closure(data: Tuple[str, Any], *, accessor: Callable[[Tuple[str, Any]], D], value: D) -> bool:
return accessor(data) == value
def path_comparator(getter: Union[str, Callable[[pathlib.Path], D]], value: D) -> Callable[[Tuple[str, Any]], bool]:
accessor = path_accessor(getter)
def wrapper(data: Tuple[str, Any]) -> bool:
return accessor(data) == value
return wrapper
return functools.partial(_path_comparator_closure, accessor=path_accessor(getter), value=value)
class CompressionType(enum.Enum):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment