Unverified Commit b4686f2b authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Fully exhaust datapipes that are needed to construct a dataset (#6076)

parent 2c19af37
...@@ -13,6 +13,7 @@ from torchdata.datapipes.iter import ( ...@@ -13,6 +13,7 @@ from torchdata.datapipes.iter import (
LineReader, LineReader,
Mapper, Mapper,
) )
from torchdata.datapipes.map import IterToMapConverter
from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource from torchvision.prototype.datasets.utils import Dataset, GDriveResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import ( from torchvision.prototype.datasets.utils._internal import (
getitem, getitem,
...@@ -114,6 +115,9 @@ class CUB200(Dataset): ...@@ -114,6 +115,9 @@ class CUB200(Dataset):
else: else:
return None return None
def _2011_extract_file_name(self, rel_posix_path: str) -> str:
return rel_posix_path.rsplit("/", maxsplit=1)[1]
def _2011_filter_split(self, row: List[str]) -> bool: def _2011_filter_split(self, row: List[str]) -> bool:
_, split_id = row _, split_id = row
return { return {
...@@ -185,17 +189,16 @@ class CUB200(Dataset): ...@@ -185,17 +189,16 @@ class CUB200(Dataset):
) )
image_files_dp = CSVParser(image_files_dp, dialect="cub200") image_files_dp = CSVParser(image_files_dp, dialect="cub200")
image_files_map = dict( image_files_dp = Mapper(image_files_dp, self._2011_extract_file_name, input_col=1)
(image_id, rel_posix_path.rsplit("/", maxsplit=1)[1]) for image_id, rel_posix_path in image_files_dp image_files_map = IterToMapConverter(image_files_dp)
)
split_dp = CSVParser(split_dp, dialect="cub200") split_dp = CSVParser(split_dp, dialect="cub200")
split_dp = Filter(split_dp, self._2011_filter_split) split_dp = Filter(split_dp, self._2011_filter_split)
split_dp = Mapper(split_dp, getitem(0)) split_dp = Mapper(split_dp, getitem(0))
split_dp = Mapper(split_dp, image_files_map.get) split_dp = Mapper(split_dp, image_files_map.__getitem__)
bounding_boxes_dp = CSVParser(bounding_boxes_dp, dialect="cub200") bounding_boxes_dp = CSVParser(bounding_boxes_dp, dialect="cub200")
bounding_boxes_dp = Mapper(bounding_boxes_dp, image_files_map.get, input_col=0) bounding_boxes_dp = Mapper(bounding_boxes_dp, image_files_map.__getitem__, input_col=0)
anns_dp = IterKeyZipper( anns_dp = IterKeyZipper(
bounding_boxes_dp, bounding_boxes_dp,
......
import enum import enum
import functools
import pathlib import pathlib
import re import re
from typing import Any, BinaryIO, cast, Dict, List, Match, Optional, Tuple, Union
from typing import Any, BinaryIO, cast, Dict, Iterator, List, Match, Optional, Tuple, Union
from torchdata.datapipes.iter import ( from torchdata.datapipes.iter import (
Demultiplexer, Demultiplexer,
...@@ -14,6 +14,7 @@ from torchdata.datapipes.iter import ( ...@@ -14,6 +14,7 @@ from torchdata.datapipes.iter import (
Mapper, Mapper,
TarArchiveLoader, TarArchiveLoader,
) )
from torchdata.datapipes.map import IterToMapConverter
from torchvision.prototype.datasets.utils import Dataset, ManualDownloadResource, OnlineResource from torchvision.prototype.datasets.utils import Dataset, ManualDownloadResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import ( from torchvision.prototype.datasets.utils._internal import (
getitem, getitem,
...@@ -47,6 +48,28 @@ class ImageNetDemux(enum.IntEnum): ...@@ -47,6 +48,28 @@ class ImageNetDemux(enum.IntEnum):
LABEL = 1 LABEL = 1
class CategoryAndWordNetIDExtractor(IterDataPipe):
    """Yield ``(category, wnid)`` pairs parsed from the devkit's synsets table.

    Each item of the wrapped datapipe is a ``(path, stream)`` tuple whose
    stream holds a MATLAB file with a ``"synsets"`` table. Only leaf synsets
    (``num_children == 0``) are yielded; entries with children are
    superclasses without direct image instances and are skipped.
    """

    # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
    # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
    _WNID_MAP = {
        "n03126707": "construction crane",
        "n03710721": "tank suit",
    }

    def __init__(self, datapipe: IterDataPipe[Tuple[str, BinaryIO]]) -> None:
        # Initialize the IterDataPipe machinery before storing state.
        super().__init__()
        self.datapipe = datapipe

    def __iter__(self) -> Iterator[Tuple[str, str]]:
        for _, stream in self.datapipe:
            synsets = read_mat(stream, squeeze_me=True)["synsets"]
            for _, wnid, category, _, num_children, *_ in synsets:
                if num_children > 0:
                    # we are looking at a superclass that has no direct instance
                    continue
                # Disambiguate duplicate category names via _WNID_MAP; otherwise
                # keep only the part of the category before the first comma.
                yield self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid
@register_dataset(NAME) @register_dataset(NAME)
class ImageNet(Dataset): class ImageNet(Dataset):
""" """
...@@ -110,25 +133,6 @@ class ImageNet(Dataset): ...@@ -110,25 +133,6 @@ class ImageNet(Dataset):
"ILSVRC2012_validation_ground_truth.txt": ImageNetDemux.LABEL, "ILSVRC2012_validation_ground_truth.txt": ImageNetDemux.LABEL,
}.get(pathlib.Path(data[0]).name) }.get(pathlib.Path(data[0]).name)
# Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
# and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
# Maps such ambiguous wnids to a manually disambiguated category name.
_WNID_MAP = {
"n03126707": "construction crane",
"n03710721": "tank suit",
}
def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]:
    """Parse ``(category, wnid)`` pairs from an opened devkit meta stream.

    ``data`` is a ``(path, stream)`` tuple; only the stream is read.
    """
    synsets = read_mat(data[1], squeeze_me=True)["synsets"]
    pairs: List[Tuple[str, str]] = []
    for _, wnid, category, _, num_children, *_ in synsets:
        if num_children > 0:
            # superclass without direct instances -> not a usable label
            continue
        # prefer the manual disambiguation map, else the text before the first comma
        name = self._WNID_MAP.get(wnid, category.split(",", 1)[0])
        pairs.append((name, wnid))
    return pairs
def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: Tuple[str, ...]) -> str:
return wnids[int(imagenet_label) - 1]
_VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG") _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
def _val_test_image_key(self, path: pathlib.Path) -> int: def _val_test_image_key(self, path: pathlib.Path) -> int:
...@@ -172,12 +176,15 @@ class ImageNet(Dataset): ...@@ -172,12 +176,15 @@ class ImageNet(Dataset):
devkit_dp, 2, self._classifiy_devkit, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE devkit_dp, 2, self._classifiy_devkit, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
) )
meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids) # We cannot use self._wnids here, since we use a different order than the dataset
_, wnids = zip(*next(iter(meta_dp))) meta_dp = CategoryAndWordNetIDExtractor(meta_dp)
wnid_dp = Mapper(meta_dp, getitem(1))
wnid_dp = Enumerator(wnid_dp, 1)
wnid_map = IterToMapConverter(wnid_dp)
label_dp = LineReader(label_dp, decode=True, return_path=False) label_dp = LineReader(label_dp, decode=True, return_path=False)
# We cannot use self._wnids here, since we use a different order than the dataset label_dp = Mapper(label_dp, int)
label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids)) label_dp = Mapper(label_dp, wnid_map.__getitem__)
label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1) label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1)
label_dp = hint_shuffling(label_dp) label_dp = hint_shuffling(label_dp)
label_dp = hint_sharding(label_dp) label_dp = hint_sharding(label_dp)
...@@ -209,8 +216,8 @@ class ImageNet(Dataset): ...@@ -209,8 +216,8 @@ class ImageNet(Dataset):
devkit_dp = resources[1].load(self._root) devkit_dp = resources[1].load(self._root)
meta_dp = Filter(devkit_dp, self._filter_meta) meta_dp = Filter(devkit_dp, self._filter_meta)
meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids) meta_dp = CategoryAndWordNetIDExtractor(meta_dp)
categories_and_wnids = cast(List[Tuple[str, ...]], next(iter(meta_dp))) categories_and_wnids = cast(List[Tuple[str, ...]], list(meta_dp))
categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1]) categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1])
return categories_and_wnids return categories_and_wnids
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment