Unverified commit 5e0959a0, authored by Philip Meier and committed by GitHub
Browse files

add third party dependencies to prototype datasets (#4962)

* add third party dependencies to prototype datasets

* add test for Dataset.to_datapipe

* add missing annotation
parent 220e0ff9
...@@ -151,6 +151,12 @@ class TestDatasetInfo: ...@@ -151,6 +151,12 @@ class TestDatasetInfo:
with pytest.raises(ValueError, match=expected_error_msg): with pytest.raises(ValueError, match=expected_error_msg):
info.make_config(**options) info.make_config(**options)
def test_check_dependencies(self):
    """An info object listing an unimportable package must fail its dependency check.

    The raised ``ModuleNotFoundError`` is expected to mention the package name.
    """
    missing_package = "fake_dependency"
    dataset_info = make_minimal_dataset_info(dependencies=(missing_package,))

    with pytest.raises(ModuleNotFoundError, match=missing_package):
        dataset_info.check_dependencies()
def test_repr(self, info): def test_repr(self, info):
output = repr(info) output = repr(info)
...@@ -169,21 +175,14 @@ class TestDatasetInfo: ...@@ -169,21 +175,14 @@ class TestDatasetInfo:
class TestDataset: class TestDataset:
class DatasetMock(datasets.utils.Dataset): class DatasetMock(datasets.utils.Dataset):
def __init__(self, name="name", *, valid_options=None, resources=None): def __init__(self, info=None, *, resources=None):
self._name = name self._info = info or make_minimal_dataset_info(valid_options=dict(split=("train", "test")))
self._valid_options = valid_options or dict(split=("train", "test"))
self.resources = unittest.mock.Mock(return_value=[]) if resources is None else lambda config: resources self.resources = unittest.mock.Mock(return_value=[]) if resources is None else lambda config: resources
self._make_datapipe = unittest.mock.Mock() self._make_datapipe = unittest.mock.Mock()
super().__init__() super().__init__()
def _make_info(self): def _make_info(self):
return datasets.utils.DatasetInfo( return self._info
self._name,
type=datasets.utils.DatasetType.RAW,
categories=[],
valid_options=self._valid_options,
)
def resources(self, config): def resources(self, config):
# This method is just defined to appease the ABC, but will be overwritten at instantiation # This method is just defined to appease the ABC, but will be overwritten at instantiation
...@@ -195,14 +194,13 @@ class TestDataset: ...@@ -195,14 +194,13 @@ class TestDataset:
def test_name(self): def test_name(self):
name = "sentinel" name = "sentinel"
dataset = self.DatasetMock(name=name) dataset = self.DatasetMock(make_minimal_dataset_info(name=name))
assert dataset.name == name assert dataset.name == name
def test_default_config(self): def test_default_config(self):
sentinel = "sentinel" sentinel = "sentinel"
valid_options = dict(split=(sentinel, "train")) dataset = self.DatasetMock(info=make_minimal_dataset_info(valid_options=dict(split=(sentinel, "train"))))
dataset = self.DatasetMock(valid_options=valid_options)
assert dataset.default_config == datasets.utils.DatasetConfig(split=sentinel) assert dataset.default_config == datasets.utils.DatasetConfig(split=sentinel)
...@@ -223,6 +221,12 @@ class TestDataset: ...@@ -223,6 +221,12 @@ class TestDataset:
(_, call_kwargs) = dataset._make_datapipe.call_args (_, call_kwargs) = dataset._make_datapipe.call_args
assert call_kwargs["config"] == config assert call_kwargs["config"] == config
def test_missing_dependencies(self):
    """Requesting a datapipe from a dataset with an unimportable dependency must fail.

    The raised ``ModuleNotFoundError`` is expected to mention the package name.
    """
    missing_package = "fake_dependency"
    info = make_minimal_dataset_info(dependencies=(missing_package,))
    dataset = self.DatasetMock(info)

    with pytest.raises(ModuleNotFoundError, match=missing_package):
        dataset.to_datapipe("root")
def test_resources(self, mocker): def test_resources(self, mocker):
resource_mock = mocker.Mock(spec=["to_datapipe"]) resource_mock = mocker.Mock(spec=["to_datapipe"])
sentinel = object() sentinel = object()
......
...@@ -30,6 +30,7 @@ class Caltech101(Dataset): ...@@ -30,6 +30,7 @@ class Caltech101(Dataset):
return DatasetInfo( return DatasetInfo(
"caltech101", "caltech101",
type=DatasetType.IMAGE, type=DatasetType.IMAGE,
dependencies=("scipy",),
homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101", homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101",
) )
......
...@@ -45,6 +45,7 @@ class ImageNet(Dataset): ...@@ -45,6 +45,7 @@ class ImageNet(Dataset):
return DatasetInfo( return DatasetInfo(
name, name,
type=DatasetType.IMAGE, type=DatasetType.IMAGE,
dependencies=("scipy",),
categories=categories, categories=categories,
homepage="https://www.image-net.org/", homepage="https://www.image-net.org/",
valid_options=dict(split=("train", "val", "test")), valid_options=dict(split=("train", "val", "test")),
......
...@@ -37,6 +37,7 @@ class SBD(Dataset): ...@@ -37,6 +37,7 @@ class SBD(Dataset):
return DatasetInfo( return DatasetInfo(
"sbd", "sbd",
type=DatasetType.IMAGE, type=DatasetType.IMAGE,
dependencies=("scipy",),
homepage="http://home.bharathh.info/pubs/codes/SBD/download.html", homepage="http://home.bharathh.info/pubs/codes/SBD/download.html",
valid_options=dict( valid_options=dict(
split=("train", "val", "train_noval"), split=("train", "val", "train_noval"),
......
import abc import abc
import csv import csv
import enum import enum
import importlib
import io import io
import itertools import itertools
import os import os
...@@ -32,6 +33,7 @@ class DatasetInfo: ...@@ -32,6 +33,7 @@ class DatasetInfo:
name: str, name: str,
*, *,
type: Union[str, DatasetType], type: Union[str, DatasetType],
dependencies: Sequence[str] = (),
categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None, categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None,
citation: Optional[str] = None, citation: Optional[str] = None,
homepage: Optional[str] = None, homepage: Optional[str] = None,
...@@ -42,6 +44,8 @@ class DatasetInfo: ...@@ -42,6 +44,8 @@ class DatasetInfo:
self.name = name.lower() self.name = name.lower()
self.type = DatasetType[type.upper()] if isinstance(type, str) else type self.type = DatasetType[type.upper()] if isinstance(type, str) else type
self.dependecies = dependencies
if categories is None: if categories is None:
path = BUILTIN_DIR / f"{self.name}.categories" path = BUILTIN_DIR / f"{self.name}.categories"
categories = path if path.exists() else [] categories = path if path.exists() else []
...@@ -107,6 +111,16 @@ class DatasetInfo: ...@@ -107,6 +111,16 @@ class DatasetInfo:
return DatasetConfig(self.default_config, **options) return DatasetConfig(self.default_config, **options)
def check_dependencies(self) -> None:
    """Verify that every third-party package this dataset needs can be imported.

    Each declared dependency is imported in turn; the first one that cannot
    be found aborts the check.

    Raises:
        ModuleNotFoundError: If a dependency cannot be imported. The message
            names the dataset and the missing package, and the original
            import error is attached as the cause.
    """
    # NOTE(review): ``dependecies`` (sic) matches the spelling of the attribute
    # assigned in ``__init__``; a rename has to change both places together.
    for package in self.dependecies:
        try:
            importlib.import_module(package)
        except ModuleNotFoundError as missing:
            message = (
                f"Dataset '{self.name}' depends on the third-party package '{package}'. "
                f"Please install it, for example with `pip install {package}`."
            )
            raise ModuleNotFoundError(message) from missing
def __repr__(self) -> str: def __repr__(self) -> str:
items = [("name", self.name)] items = [("name", self.name)]
for key in ("citation", "homepage", "license"): for key in ("citation", "homepage", "license"):
...@@ -172,6 +186,8 @@ class Dataset(abc.ABC): ...@@ -172,6 +186,8 @@ class Dataset(abc.ABC):
root = os.path.join(root, *config.values()) root = os.path.join(root, *config.values())
dataset_size = self.info.extra["sizes"][config] dataset_size = self.info.extra["sizes"][config]
return _make_sharded_datapipe(root, dataset_size) return _make_sharded_datapipe(root, dataset_size)
self.info.check_dependencies()
resource_dps = [resource.to_datapipe(root) for resource in self.resources(config)] resource_dps = [resource.to_datapipe(root) for resource in self.resources(config)]
return self._make_datapipe(resource_dps, config=config, decoder=decoder) return self._make_datapipe(resource_dps, config=config, decoder=decoder)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment