Unverified commit 5e0959a0, authored by Philip Meier and committed by GitHub
Browse files

add third party dependencies to prototype datasets (#4962)

* add third party dependencies to prototype datasets

* add test for Dataset.to_datapipe

* add missing annotation
parent 220e0ff9
......@@ -151,6 +151,12 @@ class TestDatasetInfo:
with pytest.raises(ValueError, match=expected_error_msg):
info.make_config(**options)
def test_check_dependencies(self):
    """A dependency that cannot be imported makes ``check_dependencies`` raise."""
    missing_package = "fake_dependency"
    info = make_minimal_dataset_info(dependencies=(missing_package,))

    # The raised error is expected to mention the package that could not be found.
    with pytest.raises(ModuleNotFoundError, match=missing_package):
        info.check_dependencies()
def test_repr(self, info):
output = repr(info)
......@@ -169,21 +175,14 @@ class TestDatasetInfo:
class TestDataset:
class DatasetMock(datasets.utils.Dataset):
def __init__(self, name="name", *, valid_options=None, resources=None):
self._name = name
self._valid_options = valid_options or dict(split=("train", "test"))
def __init__(self, info=None, *, resources=None):
self._info = info or make_minimal_dataset_info(valid_options=dict(split=("train", "test")))
self.resources = unittest.mock.Mock(return_value=[]) if resources is None else lambda config: resources
self._make_datapipe = unittest.mock.Mock()
super().__init__()
def _make_info(self):
return datasets.utils.DatasetInfo(
self._name,
type=datasets.utils.DatasetType.RAW,
categories=[],
valid_options=self._valid_options,
)
return self._info
def resources(self, config):
# This method is just defined to appease the ABC, but will be overwritten at instantiation
......@@ -195,14 +194,13 @@ class TestDataset:
def test_name(self):
    """The dataset reports the name carried by the info object it was built from."""
    name = "sentinel"
    # ``DatasetMock`` is constructed from a ready-made info object, so the name is set there.
    dataset = self.DatasetMock(make_minimal_dataset_info(name=name))

    assert dataset.name == name
def test_default_config(self):
    """The default config is expected to use the first value listed for an option."""
    sentinel = "sentinel"
    # The sentinel is placed first in the ``split`` options so it should become the default.
    info = make_minimal_dataset_info(valid_options=dict(split=(sentinel, "train")))
    dataset = self.DatasetMock(info=info)

    assert dataset.default_config == datasets.utils.DatasetConfig(split=sentinel)
......@@ -223,6 +221,12 @@ class TestDataset:
(_, call_kwargs) = dataset._make_datapipe.call_args
assert call_kwargs["config"] == config
def test_missing_dependencies(self):
    """Building a datapipe for a dataset with an unavailable dependency raises."""
    missing_package = "fake_dependency"
    info = make_minimal_dataset_info(dependencies=(missing_package,))
    dataset = self.DatasetMock(info)

    # The raised error is expected to mention the package that could not be found.
    with pytest.raises(ModuleNotFoundError, match=missing_package):
        dataset.to_datapipe("root")
def test_resources(self, mocker):
resource_mock = mocker.Mock(spec=["to_datapipe"])
sentinel = object()
......
......@@ -30,6 +30,7 @@ class Caltech101(Dataset):
return DatasetInfo(
"caltech101",
type=DatasetType.IMAGE,
dependencies=("scipy",),
homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101",
)
......
......@@ -45,6 +45,7 @@ class ImageNet(Dataset):
return DatasetInfo(
name,
type=DatasetType.IMAGE,
dependencies=("scipy",),
categories=categories,
homepage="https://www.image-net.org/",
valid_options=dict(split=("train", "val", "test")),
......
......@@ -37,6 +37,7 @@ class SBD(Dataset):
return DatasetInfo(
"sbd",
type=DatasetType.IMAGE,
dependencies=("scipy",),
homepage="http://home.bharathh.info/pubs/codes/SBD/download.html",
valid_options=dict(
split=("train", "val", "train_noval"),
......
import abc
import csv
import enum
import importlib
import io
import itertools
import os
......@@ -32,6 +33,7 @@ class DatasetInfo:
name: str,
*,
type: Union[str, DatasetType],
dependencies: Sequence[str] = (),
categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None,
citation: Optional[str] = None,
homepage: Optional[str] = None,
......@@ -42,6 +44,8 @@ class DatasetInfo:
self.name = name.lower()
self.type = DatasetType[type.upper()] if isinstance(type, str) else type
self.dependecies = dependencies
if categories is None:
path = BUILTIN_DIR / f"{self.name}.categories"
categories = path if path.exists() else []
......@@ -107,6 +111,16 @@ class DatasetInfo:
return DatasetConfig(self.default_config, **options)
def check_dependencies(self) -> None:
    """Ensure that every third-party package required by this dataset is importable.

    Each dependency is imported with ``importlib.import_module`` in order; the
    first one whose import raises ``ModuleNotFoundError`` aborts the check.

    Raises:
        ModuleNotFoundError: If importing a dependency raises
            ``ModuleNotFoundError``. The message names the dataset and the
            package and suggests installing it; the original error is chained
            as the cause.
    """
    # NOTE: the attribute is spelled ``dependecies`` (sic) because that is the
    # name assigned in ``__init__``.
    dependencies = self.dependecies
    # A bare string is itself a sequence of strings; without this guard a value
    # such as "scipy" would be checked one character at a time.
    if isinstance(dependencies, str):
        dependencies = (dependencies,)

    for dependency in dependencies:
        try:
            importlib.import_module(dependency)
        except ModuleNotFoundError as error:
            raise ModuleNotFoundError(
                f"Dataset '{self.name}' depends on the third-party package '{dependency}'. "
                f"Please install it, for example with `pip install {dependency}`."
            ) from error
def __repr__(self) -> str:
items = [("name", self.name)]
for key in ("citation", "homepage", "license"):
......@@ -172,6 +186,8 @@ class Dataset(abc.ABC):
root = os.path.join(root, *config.values())
dataset_size = self.info.extra["sizes"][config]
return _make_sharded_datapipe(root, dataset_size)
self.info.check_dependencies()
resource_dps = [resource.to_datapipe(root) for resource in self.resources(config)]
return self._make_datapipe(resource_dps, config=config, decoder=decoder)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment