Unverified Commit f9bee391 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

relax split requirement for prototype datasets (#5186)

* relax split requirement for prototype datasets

* remove obsolete tests

* appease mypy

* fix failing test

* fix load config test with default config
parent 28f72f16
......@@ -125,21 +125,6 @@ class TestDatasetInfo:
assert info.default_config == default_config
@pytest.mark.parametrize(
"valid_options",
[
pytest.param(None, id="default"),
pytest.param(dict(option=("value",)), id="no_split"),
],
)
def test_default_config_split_train(self, valid_options):
info = make_minimal_dataset_info(valid_options=valid_options)
assert info.default_config.split == "train"
def test_valid_options_split_but_no_train(self):
with pytest.raises(ValueError, match="'train' has to be a valid argument for option 'split'"):
make_minimal_dataset_info(valid_options=dict(split=("test",)))
@pytest.mark.parametrize(
("options", "expected_error_msg"),
[
......@@ -208,7 +193,7 @@ class TestDataset:
("config", "kwarg"),
[
pytest.param(*(datasets.utils.DatasetConfig(split="test"),) * 2, id="specific"),
pytest.param(make_minimal_dataset_info().default_config, None, id="default"),
pytest.param(DatasetMock().default_config, None, id="default"),
],
)
def test_load_config(self, config, kwarg):
......@@ -218,7 +203,7 @@ class TestDataset:
dataset.resources.assert_called_with(config)
(_, call_kwargs) = dataset._make_datapipe.call_args
_, call_kwargs = dataset._make_datapipe.call_args
assert call_kwargs["config"] == config
def test_missing_dependencies(self):
......
......@@ -45,14 +45,32 @@ class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]):
class _CifarBase(Dataset):
_FILE_NAME: str
_SHA256: str
_LABELS_KEY: str
_META_FILE_NAME: str
_CATEGORIES_KEY: str
@abc.abstractmethod
def _is_data_file(self, data: Tuple[str, io.IOBase], *, split: str) -> Optional[int]:
    """Decide whether an archive member belongs to the requested split.

    Args:
        data: ``(path, file_handle)`` pair produced by the archive reader.
        split: The configured split, e.g. ``"train"`` or ``"test"``.

    NOTE(review): diff residue left both the old ``config: DatasetConfig``
    and the new ``split: str`` signatures in place; the post-change
    keyword-only ``split`` form is kept, matching the ``Filter`` call site
    (``functools.partial(self._is_data_file, split=config.split)``).
    Concrete subclasses actually return ``bool`` (a subtype of ``int``).
    """
    pass
def _make_info(self) -> DatasetInfo:
    """Build the shared dataset metadata for both CIFAR variants.

    The dataset name is derived from the concrete subclass name
    (e.g. ``Cifar10`` -> ``"cifar10"``).
    """
    name = type(self).__name__.lower()
    return DatasetInfo(
        name,
        type=DatasetType.RAW,
        homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
        valid_options=dict(split=("train", "test")),
    )
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
    """Return the single archive download for this CIFAR variant.

    ``_FILE_NAME`` and ``_SHA256`` are supplied by the concrete subclass;
    ``config`` is unused because both splits live in the same archive.
    """
    url = f"https://www.cs.toronto.edu/~kriz/{self._FILE_NAME}"
    archive = HttpResource(url, sha256=self._SHA256)
    return [archive]
def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
    """Deserialize one pickled CIFAR batch file into a dict.

    ``encoding="latin1"`` is required because the batches were pickled
    under Python 2.

    NOTE(review): ``pickle.load`` on downloaded data can execute arbitrary
    code if the archive is tampered with; integrity rests on the sha256
    check of the download — confirm that check is enforced upstream.
    """
    _path, file_handle = data
    payload = pickle.load(file_handle, encoding="latin1")
    return cast(Dict[str, Any], payload)
......@@ -84,7 +102,7 @@ class _CifarBase(Dataset):
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = Filter(dp, functools.partial(self._is_data_file, config=config))
dp = Filter(dp, functools.partial(self._is_data_file, split=config.split))
dp = Mapper(dp, self._unpickle)
dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
dp = hint_sharding(dp)
......@@ -102,53 +120,24 @@ class _CifarBase(Dataset):
class Cifar10(_CifarBase):
    """CIFAR-10: archive constants plus the file-selection predicate.

    NOTE(review): diff residue left both the pre-change code (the
    ``config``-based ``_is_data_file`` and the per-class ``_make_info`` /
    ``resources`` that moved into ``_CifarBase``) and the post-change code
    in place; only the post-change class is kept.
    """

    _FILE_NAME = "cifar-10-python.tar.gz"
    _SHA256 = "6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce"
    _LABELS_KEY = "labels"
    _META_FILE_NAME = "batches.meta"
    _CATEGORIES_KEY = "label_names"

    def _is_data_file(self, data: Tuple[str, Any], *, split: str) -> bool:
        """Keep training batches (``data_batch_*``) or the ``test_batch`` file."""
        path = pathlib.Path(data[0])
        return path.name.startswith("data" if split == "train" else "test")
class Cifar100(_CifarBase):
    """CIFAR-100: archive constants plus the file-selection predicate.

    NOTE(review): diff residue left both the pre-change code (the
    ``config``-based ``_is_data_file`` and the per-class ``_make_info`` /
    ``resources`` that moved into ``_CifarBase``) and the post-change code
    in place; only the post-change class is kept.
    """

    _FILE_NAME = "cifar-100-python.tar.gz"
    _SHA256 = "85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7"
    _LABELS_KEY = "fine_labels"
    _META_FILE_NAME = "meta"
    _CATEGORIES_KEY = "fine_label_names"

    def _is_data_file(self, data: Tuple[str, Any], *, split: str) -> bool:
        """The CIFAR-100 archive contains files literally named "train" and "test"."""
        path = pathlib.Path(data[0])
        return path.name == split
......@@ -37,8 +37,7 @@ class OxfordIITPet(Dataset):
type=DatasetType.IMAGE,
homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/",
valid_options=dict(
# FIXME
split=("trainval", "test", "train"),
split=("trainval", "test"),
),
)
......
......@@ -38,7 +38,7 @@ class DatasetInfo:
citation: Optional[str] = None,
homepage: Optional[str] = None,
license: Optional[str] = None,
valid_options: Optional[Dict[str, Sequence]] = None,
valid_options: Optional[Dict[str, Sequence[Any]]] = None,
extra: Optional[Dict[str, Any]] = None,
) -> None:
self.name = name.lower()
......@@ -60,20 +60,10 @@ class DatasetInfo:
self.homepage = homepage
self.license = license
valid_split: Dict[str, Sequence] = dict(split=["train"])
if valid_options is None:
valid_options = valid_split
elif "split" not in valid_options:
valid_options.update(valid_split)
elif "train" not in valid_options["split"]:
raise ValueError(
f"'train' has to be a valid argument for option 'split', "
f"but found only {sequence_to_str(valid_options['split'], separate_last='and ')}."
)
self._valid_options: Dict[str, Sequence] = valid_options
self._valid_options = valid_options or dict()
self._configs = tuple(
DatasetConfig(**dict(zip(valid_options.keys(), combination)))
for combination in itertools.product(*valid_options.values())
DatasetConfig(**dict(zip(self._valid_options.keys(), combination)))
for combination in itertools.product(*self._valid_options.values())
)
self.extra = FrozenBunch(extra or dict())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment