Unverified commit 4d00ae0f authored by Philip Meier, committed by GitHub

add download functionality to prototype datasets (#5035)

* add download functionality to prototype datasets

* fix annotation

* fix test

* remove iopath

* add comments
parent 4282c9fc
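
The changes below let the prototype datasets download their resources into the datasets home directory on first use and verify their SHA256 checksums. A minimal usage sketch, assuming `load` is exposed on `torchvision.prototype.datasets` as the hunks below suggest and that the dataset is registered under the name "caltech101" (the name and options are illustrative, not part of this commit):

    from torchvision.prototype import datasets

    # Missing files are downloaded into the datasets home directory; pass
    # skip_integrity_check=True to skip the SHA256 verification after the download.
    samples = datasets.load("caltech101", split="train", skip_integrity_check=False)
    print(next(iter(samples)).keys())
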
@@ -29,6 +29,19 @@ __all__ = ["load"]
 DEFAULT_TEST_DECODER = object()


+class TestResource(datasets.utils.OnlineResource):
+    def __init__(self, *, dataset_name, dataset_config, **kwargs):
+        super().__init__(**kwargs)
+        self.dataset_name = dataset_name
+        self.dataset_config = dataset_config
+
+    def _download(self, _):
+        raise pytest.UsageError(
+            f"Dataset '{self.dataset_name}' requires the file '{self.file_name}' for {self.dataset_config}, "
+            f"but this file does not exist."
+        )
+
+
 class DatasetMocks:
     def __init__(self):
         self._mock_data_fns = {}
@@ -72,7 +85,7 @@ class DatasetMocks:
         )
         return mock_info

-    def _get(self, dataset, config):
+    def _get(self, dataset, config, root):
         name = dataset.info.name
         resources_and_mock_info = self._cache.get((name, config))
         if resources_and_mock_info:
@@ -87,20 +100,12 @@
                f"Did you register the mock data function with `@DatasetMocks.register_mock_data_fn`?"
            )

-        root = self._tmp_home / name
-        root.mkdir(exist_ok=True)
+        mock_resources = [
+            TestResource(dataset_name=name, dataset_config=config, file_name=resource.file_name)
+            for resource in dataset.resources(config)
+        ]

         mock_info = self._parse_mock_info(fakedata_fn(dataset.info, root, config), name=name)

-        mock_resources = []
-        for resource in dataset.resources(config):
-            path = root / resource.file_name
-            if not path.exists() and path.is_file():
-                raise pytest.UsageError(
-                    f"Dataset '{name}' requires the file {path.name} for {config}, but this file does not exist."
-                )
-            mock_resources.append(datasets.utils.LocalResource(path))
-
         self._cache[(name, config)] = mock_resources, mock_info
         return mock_resources, mock_info
@@ -109,9 +114,13 @@
     ) -> Tuple[IterDataPipe, Dict[str, Any]]:
         dataset = find(name)
         config = dataset.info.make_config(split=split, **options)
-        resources, mock_info = self._get(dataset, config)
+
+        root = self._tmp_home / name
+        root.mkdir(exist_ok=True)
+        resources, mock_info = self._get(dataset, config, root)
+
         datapipe = dataset._make_datapipe(
-            [resource.to_datapipe() for resource in resources],
+            [resource.load(root) for resource in resources],
             config=config,
             decoder=DEFAULT_DECODER_MAP.get(dataset.info.type) if decoder is DEFAULT_DECODER else decoder,
         )
......
@@ -211,10 +211,10 @@ class TestDataset:
             pytest.param(make_minimal_dataset_info().default_config, None, id="default"),
         ],
     )
-    def test_to_datapipe_config(self, config, kwarg):
+    def test_load_config(self, config, kwarg):
         dataset = self.DatasetMock()

-        dataset.to_datapipe("", config=kwarg)
+        dataset.load("", config=kwarg)

         dataset.resources.assert_called_with(config)
@@ -225,18 +225,19 @@
         dependency = "fake_dependency"
         dataset = self.DatasetMock(make_minimal_dataset_info(dependencies=(dependency,)))
         with pytest.raises(ModuleNotFoundError, match=dependency):
-            dataset.to_datapipe("root")
+            dataset.load("root")

     def test_resources(self, mocker):
-        resource_mock = mocker.Mock(spec=["to_datapipe"])
+        resource_mock = mocker.Mock(spec=["load"])
         sentinel = object()
-        resource_mock.to_datapipe.return_value = sentinel
+        resource_mock.load.return_value = sentinel

         dataset = self.DatasetMock(resources=[resource_mock])

         root = "root"
-        dataset.to_datapipe(root)
+        dataset.load(root)

-        resource_mock.to_datapipe.assert_called_with(root)
+        (call_args, _) = resource_mock.load.call_args
+        assert call_args[0] == root

         (call_args, _) = dataset._make_datapipe.call_args
         assert call_args[0][0] is sentinel
@@ -245,7 +246,7 @@
         dataset = self.DatasetMock()
         sentinel = object()

-        dataset.to_datapipe("", decoder=sentinel)
+        dataset.load("", decoder=sentinel)

         (_, call_kwargs) = dataset._make_datapipe.call_args
         assert call_kwargs["decoder"] is sentinel
@@ -61,16 +61,16 @@ def load(
     name: str,
     *,
     decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = DEFAULT_DECODER,  # type: ignore[assignment]
+    skip_integrity_check: bool = False,
     split: str = "train",
     **options: Any,
 ) -> IterDataPipe[Dict[str, Any]]:
-    name = name.lower()
     dataset = find(name)

     if decoder is DEFAULT_DECODER:
         decoder = DEFAULT_DECODER_MAP.get(dataset.info.type)

     config = dataset.info.make_config(split=split, **options)
-    root = os.path.join(home(), name)
+    root = os.path.join(home(), dataset.name)

-    return dataset.to_datapipe(root, config=config, decoder=decoder)
+    return dataset.load(root, config=config, decoder=decoder, skip_integrity_check=skip_integrity_check)
@@ -8,7 +8,6 @@ import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
-    TarArchiveReader,
     Shuffler,
     Filter,
     IterKeyZipper,
@@ -38,6 +37,7 @@ class Caltech101(Dataset):
         images = HttpResource(
             "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
             sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926",
+            decompress=True,
         )
         anns = HttpResource(
             "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
@@ -119,11 +119,9 @@ class Caltech101(Dataset):
     ) -> IterDataPipe[Dict[str, Any]]:
         images_dp, anns_dp = resource_dps

-        images_dp = TarArchiveReader(images_dp)
         images_dp = Filter(images_dp, self._is_not_background_image)
         images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)

-        anns_dp = TarArchiveReader(anns_dp)
         anns_dp = Filter(anns_dp, self._is_ann)

         dp = IterKeyZipper(
@@ -137,8 +135,7 @@ class Caltech101(Dataset):
         return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
-        dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
-        dp = TarArchiveReader(dp)
+        dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
         dp = Filter(dp, self._is_not_background_image)
         return sorted({pathlib.Path(path).parent.name for path, _ in dp})
@@ -185,13 +182,11 @@ class Caltech256(Dataset):
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
-        dp = TarArchiveReader(dp)
         dp = Filter(dp, self._is_not_rogue_file)
         dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
         return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
-        dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
-        dp = TarArchiveReader(dp)
+        dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
         dir_names = {pathlib.Path(path).parent.name for path, _ in dp}
         return [name.split(".")[1] for name in sorted(dir_names)]
@@ -8,7 +8,6 @@ from torchdata.datapipes.iter import (
     Mapper,
     Shuffler,
     Filter,
-    ZipArchiveReader,
     Zipper,
     IterKeyZipper,
 )
@@ -154,8 +153,6 @@ class CelebA(Dataset):
         splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
         splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)

-        images_dp = ZipArchiveReader(images_dp)
-
         anns_dp = Zipper(
             *[
                 CelebACSVParser(dp, fieldnames=fieldnames)
......
@@ -11,7 +11,6 @@ from torchdata.datapipes.iter import (
     IterDataPipe,
     Filter,
     Mapper,
-    TarArchiveReader,
     Shuffler,
 )
 from torchvision.prototype.datasets.decoder import raw
@@ -85,7 +84,6 @@ class _CifarBase(Dataset):
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
-        dp = TarArchiveReader(dp)
         dp = Filter(dp, functools.partial(self._is_data_file, config=config))
         dp = Mapper(dp, self._unpickle)
         dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
@@ -93,8 +91,7 @@ class _CifarBase(Dataset):
         return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
-        dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
-        dp = TarArchiveReader(dp)
+        dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
         dp = Filter(dp, path_comparator("name", self._META_FILE_NAME))
         dp = Mapper(dp, self._unpickle)
         return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY])
......
@@ -11,7 +11,6 @@ from torchdata.datapipes.iter import (
     Shuffler,
     Filter,
     Demultiplexer,
-    ZipArchiveReader,
     Grouper,
     IterKeyZipper,
     JsonParser,
@@ -180,13 +179,10 @@ class Coco(Dataset):
     ) -> IterDataPipe[Dict[str, Any]]:
         images_dp, meta_dp = resource_dps

-        images_dp = ZipArchiveReader(images_dp)
-
         if config.annotations is None:
             dp = Shuffler(images_dp)
             return Mapper(dp, self._collate_and_decode_image, fn_kwargs=dict(decoder=decoder))

-        meta_dp = ZipArchiveReader(meta_dp)
         meta_dp = Filter(
             meta_dp,
             self._filter_meta_files,
@@ -234,8 +230,7 @@ class Coco(Dataset):
         config = self.default_config
         resources = self.resources(config)

-        dp = resources[1].to_datapipe(pathlib.Path(root) / self.name)
-        dp = ZipArchiveReader(dp)
+        dp = resources[1].load(pathlib.Path(root) / self.name)
         dp = Filter(
             dp, self._filter_meta_files, fn_kwargs=dict(split=config.split, year=config.year, annotations="instances")
         )
......
@@ -9,8 +9,8 @@ from torchvision.prototype.datasets.utils import (
     Dataset,
     DatasetConfig,
     DatasetInfo,
-    HttpResource,
     OnlineResource,
+    ManualDownloadResource,
     DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import (
@@ -25,6 +25,11 @@ from torchvision.prototype.features import Label, DEFAULT
 from torchvision.prototype.utils._internal import FrozenMapping


+class ImageNetResource(ManualDownloadResource):
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs)
+
+
 class ImageNetLabel(Label):
     wnid: Optional[str]
@@ -81,10 +86,10 @@ class ImageNet(Dataset):
     def resources(self, config: DatasetConfig) -> List[OnlineResource]:
         name = "test_v10102019" if config.split == "test" else config.split
-        images = HttpResource(f"ILSVRC2012_img_{name}.tar", sha256=self._IMAGES_CHECKSUMS[name])
+        images = ImageNetResource(file_name=f"ILSVRC2012_img_{name}.tar", sha256=self._IMAGES_CHECKSUMS[name])

-        devkit = HttpResource(
-            "ILSVRC2012_devkit_t12.tar.gz",
+        devkit = ImageNetResource(
+            file_name="ILSVRC2012_devkit_t12.tar.gz",
             sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953",
         )
@@ -139,15 +144,12 @@
     ) -> IterDataPipe[Dict[str, Any]]:
         images_dp, devkit_dp = resource_dps

-        images_dp = TarArchiveReader(images_dp)
-
         if config.split == "train":
             # the train archive is a tar of tars
             dp = TarArchiveReader(images_dp)
             dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
             dp = Mapper(dp, self._collate_train_data)
         elif config.split == "val":
-            devkit_dp = TarArchiveReader(devkit_dp)
             devkit_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt"))
             devkit_dp = LineReader(devkit_dp, return_path=False)
             devkit_dp = Mapper(devkit_dp, int)
@@ -177,8 +179,7 @@
     def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, ...]]:
         resources = self.resources(self.default_config)

-        devkit_dp = resources[1].to_datapipe(root / self.name)
-        devkit_dp = TarArchiveReader(devkit_dp)
+        devkit_dp = resources[1].load(root / self.name)
         devkit_dp = Filter(devkit_dp, path_comparator("name", "meta.mat"))

         meta = next(iter(devkit_dp))[1]
......
@@ -11,7 +11,6 @@ from torchdata.datapipes.iter import (
     IterDataPipe,
     Demultiplexer,
     Mapper,
-    ZipArchiveReader,
     Zipper,
     Shuffler,
 )
@@ -310,7 +309,6 @@ class EMNIST(_MNISTBase):
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         archive_dp = resource_dps[0]
-        archive_dp = ZipArchiveReader(archive_dp)
         images_dp, labels_dp = Demultiplexer(
             archive_dp,
             2,
......
@@ -8,7 +8,6 @@ import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
-    TarArchiveReader,
     Shuffler,
     Demultiplexer,
     Filter,
@@ -129,7 +128,6 @@ class SBD(Dataset):
         archive_dp, extra_split_dp = resource_dps

         archive_dp = resource_dps[0]
-        archive_dp = TarArchiveReader(archive_dp)
         split_dp, images_dp, anns_dp = Demultiplexer(
             archive_dp,
             3,
@@ -155,8 +153,7 @@ class SBD(Dataset):
         return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(config=config, decoder=decoder))

     def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
-        dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
-        dp = TarArchiveReader(dp)
+        dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
         dp = Filter(dp, path_comparator("name", "category_names.m"))
         dp = LineReader(dp)
         dp = Mapper(dp, bytes.decode, input_col=1)
......
@@ -30,11 +30,11 @@ class SEMEION(Dataset):
     )

     def resources(self, config: DatasetConfig) -> List[OnlineResource]:
-        archive = HttpResource(
+        data = HttpResource(
             "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data",
             sha256="f43228ae3da5ea6a3c95069d53450b86166770e3b719dcc333182128fe08d4b1",
         )
-        return [archive]
+        return [data]

     def _collate_and_decode_sample(
         self,
......
@@ -8,7 +8,6 @@ import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
-    TarArchiveReader,
     Shuffler,
     Filter,
     Demultiplexer,
@@ -119,7 +118,6 @@ class VOC(Dataset):
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         archive_dp = resource_dps[0]
-        archive_dp = TarArchiveReader(archive_dp)
         split_dp, images_dp, anns_dp = Demultiplexer(
             archive_dp,
             3,
......
@@ -2,6 +2,7 @@
 import argparse
 import csv
+import pathlib
 import sys

 from torchvision.prototype import datasets
@@ -10,7 +11,7 @@ from torchvision.prototype.datasets.utils._internal import BUILTIN_DIR
 def main(*names, force=False):
-    root = datasets.home()
+    root = pathlib.Path(datasets.home())

     for name in names:
         path = BUILTIN_DIR / f"{name}.categories"
@@ -24,7 +25,8 @@ def main(*names, force=False):
             continue

         with open(path, "w", newline="") as file:
-            csv.writer(file).writerows(categories)
+            for category in categories:
+                csv.writer(file).writerow((category,) if isinstance(category, str) else category)


 def parse_args(argv=None):
......
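
The category-file change above works around a csv quirk: writerows treats each element as a row, and a plain string is itself a sequence of characters, so an unwrapped category name would be written one character per column. A standalone illustration of the difference (not part of the diff; the category name is made up):

    import csv
    import io

    buffer = io.StringIO()
    csv.writer(buffer).writerows(["airplane"])  # the string is iterated character by character
    print(buffer.getvalue().strip())            # a,i,r,p,l,a,n,e

    buffer = io.StringIO()
    csv.writer(buffer).writerow(("airplane",))  # wrapping in a 1-tuple keeps the name whole
    print(buffer.getvalue().strip())            # airplane
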
 from . import _internal
 from ._dataset import DatasetType, DatasetConfig, DatasetInfo, Dataset
 from ._query import SampleQuery
-from ._resource import LocalResource, OnlineResource, HttpResource, GDriveResource
+from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource
@@ -172,12 +172,13 @@ class Dataset(abc.ABC):
     def supports_sharded(self) -> bool:
         return False

-    def to_datapipe(
+    def load(
         self,
         root: Union[str, pathlib.Path],
         *,
         config: Optional[DatasetConfig] = None,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = None,
+        skip_integrity_check: bool = False,
     ) -> IterDataPipe[Dict[str, Any]]:
         if not config:
             config = self.info.default_config
@@ -188,7 +189,9 @@ class Dataset(abc.ABC):
             return _make_sharded_datapipe(root, dataset_size)

         self.info.check_dependencies()
-        resource_dps = [resource.to_datapipe(root) for resource in self.resources(config)]
+        resource_dps = [
+            resource.load(root, skip_integrity_check=skip_integrity_check) for resource in self.resources(config)
+        ]
         return self._make_datapipe(resource_dps, config=config, decoder=decoder)

     def _generate_categories(self, root: pathlib.Path) -> Sequence[Union[str, Sequence[str]]]:
......
-import os.path
+import abc
+import hashlib
+import itertools
 import pathlib
-from typing import Optional, Union
+import warnings
+from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn
 from urllib.parse import urlparse

-from torch.utils.data import IterDataPipe
-from torch.utils.data.datapipes.iter import IterableWrapper
-from torchdata.datapipes.iter import IoPathFileLoader
-
-
-# FIXME
-def compute_sha256(path: pathlib.Path) -> str:
-    return ""
-
-
-class LocalResource:
-    def __init__(self, path: Union[str, pathlib.Path], *, sha256: Optional[str] = None) -> None:
-        self.path = pathlib.Path(path).expanduser().resolve()
-        self.file_name = self.path.name
-        self.sha256 = sha256 or compute_sha256(self.path)
-
-    def to_datapipe(self) -> IterDataPipe:
-        return IoPathFileLoader(IterableWrapper((str(self.path),)), mode="rb")  # type: ignore
-
-
-class OnlineResource:
-    def __init__(self, url: str, *, sha256: str, file_name: str) -> None:
-        self.url = url
-        self.sha256 = sha256
-        self.file_name = file_name
-
-    def to_datapipe(self, root: Union[str, pathlib.Path]) -> IterDataPipe:
-        path = os.path.join(root, self.file_name)
-        # FIXME
-        return IoPathFileLoader(IterableWrapper((str(path),)), mode="rb")  # type: ignore
+from torchdata.datapipes.iter import (
+    IterableWrapper,
+    FileLister,
+    FileLoader,
+    IterDataPipe,
+    ZipArchiveReader,
+    TarArchiveReader,
+    RarArchiveLoader,
+)
+from torchvision.datasets.utils import (
+    download_url,
+    _detect_file_type,
+    extract_archive,
+    _decompress,
+    download_file_from_google_drive,
+)
+
+
+class OnlineResource(abc.ABC):
+    def __init__(
+        self,
+        *,
+        file_name: str,
+        sha256: Optional[str] = None,
+        decompress: bool = False,
+        extract: bool = False,
+        preprocess: Optional[Callable[[pathlib.Path], pathlib.Path]] = None,
+        loader: Optional[Callable[[pathlib.Path], IterDataPipe[Tuple[str, IO]]]] = None,
+    ) -> None:
+        self.file_name = file_name
+        self.sha256 = sha256
+
+        if preprocess and (decompress or extract):
+            warnings.warn("The parameters 'decompress' and 'extract' are ignored when 'preprocess' is passed.")
+        elif extract:
+            preprocess = self._extract
+        elif decompress:
+            preprocess = self._decompress
+        self._preprocess = preprocess
+
+        if loader is None:
+            loader = self._default_loader
+        self._loader = loader
+
+    @staticmethod
+    def _extract(file: pathlib.Path) -> pathlib.Path:
+        return pathlib.Path(
+            extract_archive(str(file), to_path=str(file).replace("".join(file.suffixes), ""), remove_finished=False)
+        )
+
+    @staticmethod
+    def _decompress(file: pathlib.Path) -> pathlib.Path:
+        return pathlib.Path(_decompress(str(file), remove_finished=True))
+
+    def _default_loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
+        if path.is_dir():
+            return FileLoader(FileLister(str(path), recursive=True))
+
+        dp = FileLoader(IterableWrapper((str(path),)))
+        archive_loader = self._guess_archive_loader(path)
+        if archive_loader:
+            dp = archive_loader(dp)
+        return dp
+
+    _ARCHIVE_LOADERS = {
+        ".tar": TarArchiveReader,
+        ".zip": ZipArchiveReader,
+        ".rar": RarArchiveLoader,
+    }
+
+    def _guess_archive_loader(
+        self, path: pathlib.Path
+    ) -> Optional[Callable[[IterDataPipe[Tuple[str, IO]]], IterDataPipe[Tuple[str, IO]]]]:
+        try:
+            _, archive_type, _ = _detect_file_type(path.name)
+        except RuntimeError:
+            return None
+        return self._ARCHIVE_LOADERS.get(archive_type)  # type: ignore[arg-type]
+
+    def load(
+        self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False
+    ) -> IterDataPipe[Tuple[str, IO]]:
+        root = pathlib.Path(root)
+        path = root / self.file_name
+        # Instead of the raw file, there might also be files with fewer suffixes after decompression or directories
+        # with no suffixes at all. Thus, we look for all paths that share the same name without suffixes as the raw
+        # file.
+        path_candidates = {file for file in path.parent.glob(path.name.replace("".join(path.suffixes), "") + "*")}
+        # If we don't find anything, we try to download the raw file.
+        if not path_candidates:
+            path_candidates = {self.download(root, skip_integrity_check=skip_integrity_check)}
+        # If the only thing we find is the raw file, we use it and optionally perform some preprocessing steps.
+        if path_candidates == {path}:
+            if self._preprocess:
+                path = self._preprocess(path)
+        # Otherwise we use the path with the fewest suffixes. This gives us the extracted > decompressed > raw priority
+        # that we want.
+        else:
+            path = min(path_candidates, key=lambda path: len(path.suffixes))
+
+        return self._loader(path)
+
+    @abc.abstractmethod
+    def _download(self, root: pathlib.Path) -> None:
+        pass
+
+    def download(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> pathlib.Path:
+        root = pathlib.Path(root)
+        self._download(root)
+        path = root / self.file_name
+        if self.sha256 and not skip_integrity_check:
+            self._check_sha256(path)
+        return path
+
+    def _check_sha256(self, path: pathlib.Path, *, chunk_size: int = 1024 * 1024) -> None:
+        hash = hashlib.sha256()
+        with open(path, "rb") as file:
+            for chunk in iter(lambda: file.read(chunk_size), b""):
+                hash.update(chunk)
+        sha256 = hash.hexdigest()
+        if sha256 != self.sha256:
+            raise RuntimeError(
+                f"After the download, the SHA256 checksum of {path} didn't match the expected one: "
+                f"{sha256} != {self.sha256}"
+            )


-# TODO: add support for mirrors
-# TODO: add support for http -> https
 class HttpResource(OnlineResource):
-    def __init__(self, url: str, *, sha256: str, file_name: Optional[str] = None) -> None:
-        if not file_name:
-            file_name = os.path.basename(urlparse(url).path)
-        super().__init__(url, sha256=sha256, file_name=file_name)
+    def __init__(
+        self, url: str, *, file_name: Optional[str] = None, mirrors: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> None:
+        super().__init__(file_name=file_name or pathlib.Path(urlparse(url).path).name, **kwargs)
+        self.url = url
+        self.mirrors = mirrors
+
+    def _download(self, root: pathlib.Path) -> None:
+        for url in itertools.chain((self.url,), self.mirrors or ()):
+            try:
+                download_url(url, str(root), filename=self.file_name, md5=None)
+            # TODO: make this more precise
+            except Exception:
+                continue
+
+            return
+        else:
+            # TODO: make this more informative
+            raise RuntimeError("Download failed!")


 class GDriveResource(OnlineResource):
-    def __init__(self, id: str, *, sha256: str, file_name: str) -> None:
-        # TODO: can we maybe do a head request to extract the file name?
-        url = f"https://drive.google.com/file/d/{id}/view"
-        super().__init__(url, sha256=sha256, file_name=file_name)
+    def __init__(self, id: str, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.id = id
+
+    def _download(self, root: pathlib.Path) -> None:
+        download_file_from_google_drive(self.id, root=str(root), filename=self.file_name, md5=None)
+
+
+class ManualDownloadResource(OnlineResource):
+    def __init__(self, instructions: str, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.instructions = instructions
+
+    def _download(self, root: pathlib.Path) -> NoReturn:
+        raise RuntimeError(
+            f"The file {self.file_name} cannot be downloaded automatically. "
+            f"Please follow the instructions below and place it in {root}\n\n"
+            f"{self.instructions}"
+        )
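
The comments in OnlineResource.load above describe the lookup order: all paths sharing the stem of the raw file are collected, and the one with the fewest suffixes wins, so an extracted directory beats a decompressed file, which beats the raw download. A rough sketch of that selection rule in isolation (the concrete file names are illustrative, borrowed from the Caltech101 archive):

    import pathlib

    candidates = {
        pathlib.Path("101_ObjectCategories.tar.gz"),  # raw download, two suffixes
        pathlib.Path("101_ObjectCategories.tar"),     # decompressed, one suffix
        pathlib.Path("101_ObjectCategories"),         # extracted directory, no suffix
    }
    # Fewest suffixes wins: extracted > decompressed > raw.
    best = min(candidates, key=lambda path: len(path.suffixes))
    print(best)  # 101_ObjectCategories
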