Unverified Commit 5ea23483 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

improve DatasetInfo for prototype datasets (#4746)



* cache dataset info

* try loading categories file by default
Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
parent 460b37d2
......@@ -21,16 +21,14 @@ from torchvision.prototype.datasets.utils import (
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, BUILTIN_DIR, read_mat
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat
class Caltech101(Dataset):
@property
def info(self) -> DatasetInfo:
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"caltech101",
type=DatasetType.IMAGE,
categories=BUILTIN_DIR / "caltech101.categories",
homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101",
)
......@@ -144,12 +142,10 @@ class Caltech101(Dataset):
class Caltech256(Dataset):
@property
def info(self) -> DatasetInfo:
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"caltech256",
type=DatasetType.IMAGE,
categories=BUILTIN_DIR / "caltech256.categories",
homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256",
)
......
......@@ -58,8 +58,7 @@ class CelebACSVParser(IterDataPipe[Tuple[str, Dict[str, str]]]):
class CelebA(Dataset):
@property
def info(self) -> DatasetInfo:
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"celeba",
type=DatasetType.IMAGE,
......
......@@ -25,7 +25,6 @@ from torchvision.prototype.datasets.utils import (
)
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
BUILTIN_DIR,
image_buffer_from_array,
path_comparator,
)
......@@ -110,12 +109,10 @@ class Cifar10(_CifarBase):
path = pathlib.Path(data[0])
return path.name.startswith("data" if config.split == "train" else "test")
@property
def info(self) -> DatasetInfo:
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"cifar10",
type=DatasetType.RAW,
categories=BUILTIN_DIR / "cifar10.categories",
homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
)
......@@ -137,12 +134,10 @@ class Cifar100(_CifarBase):
path = pathlib.Path(data[0])
return path.name == cast(str, config.split)
@property
def info(self) -> DatasetInfo:
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"cifar100",
type=DatasetType.RAW,
categories=BUILTIN_DIR / "cifar100.categories",
homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
valid_options=dict(
split=("train", "test"),
......
......@@ -35,8 +35,7 @@ HERE = pathlib.Path(__file__).parent
class Coco(Dataset):
@property
def info(self) -> DatasetInfo:
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"coco",
type=DatasetType.IMAGE,
......
......@@ -25,8 +25,7 @@ from torchvision.prototype.datasets.utils._internal import (
class ImageNet(Dataset):
@property
def info(self) -> DatasetInfo:
def _make_info(self) -> DatasetInfo:
name = "imagenet"
categories, wnids = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{name}.categories"))
......
......@@ -151,8 +151,7 @@ class _MNISTBase(Dataset):
class MNIST(_MNISTBase):
@property
def info(self) -> DatasetInfo:
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"mnist",
type=DatasetType.RAW,
......@@ -182,8 +181,7 @@ class MNIST(_MNISTBase):
class FashionMNIST(MNIST):
@property
def info(self) -> DatasetInfo:
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"fashionmnist",
type=DatasetType.RAW,
......@@ -215,8 +213,7 @@ class FashionMNIST(MNIST):
class KMNIST(MNIST):
@property
def info(self) -> DatasetInfo:
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"kmnist",
type=DatasetType.RAW,
......@@ -237,8 +234,7 @@ class KMNIST(MNIST):
class EMNIST(_MNISTBase):
@property
def info(self) -> DatasetInfo:
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"emnist",
type=DatasetType.RAW,
......@@ -335,8 +331,7 @@ class EMNIST(_MNISTBase):
class QMNIST(_MNISTBase):
@property
def info(self) -> DatasetInfo:
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"qmnist",
type=DatasetType.RAW,
......
......@@ -25,7 +25,6 @@ from torchvision.prototype.datasets.utils import (
)
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
BUILTIN_DIR,
read_mat,
getitem,
path_accessor,
......@@ -34,12 +33,10 @@ from torchvision.prototype.datasets.utils._internal import (
class SBD(Dataset):
@property
def info(self) -> DatasetInfo:
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"sbd",
type=DatasetType.IMAGE,
categories=BUILTIN_DIR / "caltech256.categories",
homepage="http://home.bharathh.info/pubs/codes/SBD/download.html",
valid_options=dict(
split=("train", "val", "train_noval"),
......
......@@ -35,8 +35,7 @@ HERE = pathlib.Path(__file__).parent
class VOC(Dataset):
@property
def info(self) -> DatasetInfo:
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"voc",
type=DatasetType.IMAGE,
......
......@@ -12,7 +12,7 @@ from torchvision.prototype.datasets.utils._internal import (
sequence_to_str,
)
from ._internal import FrozenBunch, make_repr
from ._internal import FrozenBunch, make_repr, BUILTIN_DIR
from ._resource import OnlineResource
......@@ -42,8 +42,9 @@ class DatasetInfo:
self.type = DatasetType[type.upper()] if isinstance(type, str) else type
if categories is None:
categories = []
elif isinstance(categories, int):
path = BUILTIN_DIR / f"{self.name}.categories"
categories = path if path.exists() else []
if isinstance(categories, int):
categories = [str(label) for label in range(categories)]
elif isinstance(categories, (str, pathlib.Path)):
path = pathlib.Path(categories).expanduser().resolve()
......@@ -112,11 +113,17 @@ class DatasetInfo:
class Dataset(abc.ABC):
@property
def __init__(self) -> None:
self._info = self._make_info()
@abc.abstractmethod
def info(self) -> DatasetInfo:
def _make_info(self) -> DatasetInfo:
pass
@property
def info(self) -> DatasetInfo:
return self._info
@property
def name(self) -> str:
return self.info.name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment