Unverified commit 5ea23483, authored by Philip Meier, committed by GitHub
Browse files

improve DatasetInfo for prototype datasets (#4746)



* cache dataset info

* try loading categories file by default
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
parent 460b37d2
...@@ -21,16 +21,14 @@ from torchvision.prototype.datasets.utils import ( ...@@ -21,16 +21,14 @@ from torchvision.prototype.datasets.utils import (
OnlineResource, OnlineResource,
DatasetType, DatasetType,
) )
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, BUILTIN_DIR, read_mat from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat
class Caltech101(Dataset): class Caltech101(Dataset):
@property def _make_info(self) -> DatasetInfo:
def info(self) -> DatasetInfo:
return DatasetInfo( return DatasetInfo(
"caltech101", "caltech101",
type=DatasetType.IMAGE, type=DatasetType.IMAGE,
categories=BUILTIN_DIR / "caltech101.categories",
homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101", homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101",
) )
...@@ -144,12 +142,10 @@ class Caltech101(Dataset): ...@@ -144,12 +142,10 @@ class Caltech101(Dataset):
class Caltech256(Dataset): class Caltech256(Dataset):
@property def _make_info(self) -> DatasetInfo:
def info(self) -> DatasetInfo:
return DatasetInfo( return DatasetInfo(
"caltech256", "caltech256",
type=DatasetType.IMAGE, type=DatasetType.IMAGE,
categories=BUILTIN_DIR / "caltech256.categories",
homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256", homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256",
) )
......
...@@ -58,8 +58,7 @@ class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]): ...@@ -58,8 +58,7 @@ class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]):
class CelebA(Dataset): class CelebA(Dataset):
@property def _make_info(self) -> DatasetInfo:
def info(self) -> DatasetInfo:
return DatasetInfo( return DatasetInfo(
"celeba", "celeba",
type=DatasetType.IMAGE, type=DatasetType.IMAGE,
......
...@@ -25,7 +25,6 @@ from torchvision.prototype.datasets.utils import ( ...@@ -25,7 +25,6 @@ from torchvision.prototype.datasets.utils import (
) )
from torchvision.prototype.datasets.utils._internal import ( from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE, INFINITE_BUFFER_SIZE,
BUILTIN_DIR,
image_buffer_from_array, image_buffer_from_array,
path_comparator, path_comparator,
) )
...@@ -110,12 +109,10 @@ class Cifar10(_CifarBase): ...@@ -110,12 +109,10 @@ class Cifar10(_CifarBase):
path = pathlib.Path(data[0]) path = pathlib.Path(data[0])
return path.name.startswith("data" if config.split == "train" else "test") return path.name.startswith("data" if config.split == "train" else "test")
@property def _make_info(self) -> DatasetInfo:
def info(self) -> DatasetInfo:
return DatasetInfo( return DatasetInfo(
"cifar10", "cifar10",
type=DatasetType.RAW, type=DatasetType.RAW,
categories=BUILTIN_DIR / "cifar10.categories",
homepage="https://www.cs.toronto.edu/~kriz/cifar.html", homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
) )
...@@ -137,12 +134,10 @@ class Cifar100(_CifarBase): ...@@ -137,12 +134,10 @@ class Cifar100(_CifarBase):
path = pathlib.Path(data[0]) path = pathlib.Path(data[0])
return path.name == cast(str, config.split) return path.name == cast(str, config.split)
@property def _make_info(self) -> DatasetInfo:
def info(self) -> DatasetInfo:
return DatasetInfo( return DatasetInfo(
"cifar100", "cifar100",
type=DatasetType.RAW, type=DatasetType.RAW,
categories=BUILTIN_DIR / "cifar100.categories",
homepage="https://www.cs.toronto.edu/~kriz/cifar.html", homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
valid_options=dict( valid_options=dict(
split=("train", "test"), split=("train", "test"),
......
...@@ -35,8 +35,7 @@ HERE = pathlib.Path(__file__).parent ...@@ -35,8 +35,7 @@ HERE = pathlib.Path(__file__).parent
class Coco(Dataset): class Coco(Dataset):
@property def _make_info(self) -> DatasetInfo:
def info(self) -> DatasetInfo:
return DatasetInfo( return DatasetInfo(
"coco", "coco",
type=DatasetType.IMAGE, type=DatasetType.IMAGE,
......
...@@ -25,8 +25,7 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -25,8 +25,7 @@ from torchvision.prototype.datasets.utils._internal import (
class ImageNet(Dataset): class ImageNet(Dataset):
@property def _make_info(self) -> DatasetInfo:
def info(self) -> DatasetInfo:
name = "imagenet" name = "imagenet"
categories, wnids = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{name}.categories")) categories, wnids = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{name}.categories"))
......
...@@ -151,8 +151,7 @@ class _MNISTBase(Dataset): ...@@ -151,8 +151,7 @@ class _MNISTBase(Dataset):
class MNIST(_MNISTBase): class MNIST(_MNISTBase):
@property def _make_info(self) -> DatasetInfo:
def info(self) -> DatasetInfo:
return DatasetInfo( return DatasetInfo(
"mnist", "mnist",
type=DatasetType.RAW, type=DatasetType.RAW,
...@@ -182,8 +181,7 @@ class MNIST(_MNISTBase): ...@@ -182,8 +181,7 @@ class MNIST(_MNISTBase):
class FashionMNIST(MNIST): class FashionMNIST(MNIST):
@property def _make_info(self) -> DatasetInfo:
def info(self) -> DatasetInfo:
return DatasetInfo( return DatasetInfo(
"fashionmnist", "fashionmnist",
type=DatasetType.RAW, type=DatasetType.RAW,
...@@ -215,8 +213,7 @@ class FashionMNIST(MNIST): ...@@ -215,8 +213,7 @@ class FashionMNIST(MNIST):
class KMNIST(MNIST): class KMNIST(MNIST):
@property def _make_info(self) -> DatasetInfo:
def info(self) -> DatasetInfo:
return DatasetInfo( return DatasetInfo(
"kmnist", "kmnist",
type=DatasetType.RAW, type=DatasetType.RAW,
...@@ -237,8 +234,7 @@ class KMNIST(MNIST): ...@@ -237,8 +234,7 @@ class KMNIST(MNIST):
class EMNIST(_MNISTBase): class EMNIST(_MNISTBase):
@property def _make_info(self) -> DatasetInfo:
def info(self) -> DatasetInfo:
return DatasetInfo( return DatasetInfo(
"emnist", "emnist",
type=DatasetType.RAW, type=DatasetType.RAW,
...@@ -335,8 +331,7 @@ class EMNIST(_MNISTBase): ...@@ -335,8 +331,7 @@ class EMNIST(_MNISTBase):
class QMNIST(_MNISTBase): class QMNIST(_MNISTBase):
@property def _make_info(self) -> DatasetInfo:
def info(self) -> DatasetInfo:
return DatasetInfo( return DatasetInfo(
"qmnist", "qmnist",
type=DatasetType.RAW, type=DatasetType.RAW,
......
...@@ -25,7 +25,6 @@ from torchvision.prototype.datasets.utils import ( ...@@ -25,7 +25,6 @@ from torchvision.prototype.datasets.utils import (
) )
from torchvision.prototype.datasets.utils._internal import ( from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE, INFINITE_BUFFER_SIZE,
BUILTIN_DIR,
read_mat, read_mat,
getitem, getitem,
path_accessor, path_accessor,
...@@ -34,12 +33,10 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -34,12 +33,10 @@ from torchvision.prototype.datasets.utils._internal import (
class SBD(Dataset): class SBD(Dataset):
@property def _make_info(self) -> DatasetInfo:
def info(self) -> DatasetInfo:
return DatasetInfo( return DatasetInfo(
"sbd", "sbd",
type=DatasetType.IMAGE, type=DatasetType.IMAGE,
categories=BUILTIN_DIR / "caltech256.categories",
homepage="http://home.bharathh.info/pubs/codes/SBD/download.html", homepage="http://home.bharathh.info/pubs/codes/SBD/download.html",
valid_options=dict( valid_options=dict(
split=("train", "val", "train_noval"), split=("train", "val", "train_noval"),
......
...@@ -35,8 +35,7 @@ HERE = pathlib.Path(__file__).parent ...@@ -35,8 +35,7 @@ HERE = pathlib.Path(__file__).parent
class VOC(Dataset): class VOC(Dataset):
@property def _make_info(self) -> DatasetInfo:
def info(self) -> DatasetInfo:
return DatasetInfo( return DatasetInfo(
"voc", "voc",
type=DatasetType.IMAGE, type=DatasetType.IMAGE,
......
...@@ -12,7 +12,7 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -12,7 +12,7 @@ from torchvision.prototype.datasets.utils._internal import (
sequence_to_str, sequence_to_str,
) )
from ._internal import FrozenBunch, make_repr from ._internal import FrozenBunch, make_repr, BUILTIN_DIR
from ._resource import OnlineResource from ._resource import OnlineResource
...@@ -42,8 +42,9 @@ class DatasetInfo: ...@@ -42,8 +42,9 @@ class DatasetInfo:
self.type = DatasetType[type.upper()] if isinstance(type, str) else type self.type = DatasetType[type.upper()] if isinstance(type, str) else type
if categories is None: if categories is None:
categories = [] path = BUILTIN_DIR / f"{self.name}.categories"
elif isinstance(categories, int): categories = path if path.exists() else []
if isinstance(categories, int):
categories = [str(label) for label in range(categories)] categories = [str(label) for label in range(categories)]
elif isinstance(categories, (str, pathlib.Path)): elif isinstance(categories, (str, pathlib.Path)):
path = pathlib.Path(categories).expanduser().resolve() path = pathlib.Path(categories).expanduser().resolve()
...@@ -112,11 +113,17 @@ class DatasetInfo: ...@@ -112,11 +113,17 @@ class DatasetInfo:
class Dataset(abc.ABC): class Dataset(abc.ABC):
@property def __init__(self) -> None:
self._info = self._make_info()
@abc.abstractmethod @abc.abstractmethod
def info(self) -> DatasetInfo: def _make_info(self) -> DatasetInfo:
pass pass
@property
def info(self) -> DatasetInfo:
return self._info
@property @property
def name(self) -> str: def name(self) -> str:
return self.info.name return self.info.name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment