"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "143f667fb89a70f7d3d483091d4d0084ea725127"
Unverified Commit e4a4a29a authored by Philip Meier, committed by GitHub

streamline category file generation for prototype datasets (#4642)



* streamline category file generation for prototype datasets

* cleanup
Co-authored-by: Prabhat Roy <prabhatroy@fb.com>
parent 9bee9cc4
torchvision/prototype/datasets/_builtin/caltech.py
 import io
 import pathlib
 import re
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple
 import numpy as np
 import torch
@@ -21,9 +21,7 @@ from torchvision.prototype.datasets.utils import (
     OnlineResource,
     DatasetType,
 )
-from torchvision.prototype.datasets.utils._internal import create_categories_file, INFINITE_BUFFER_SIZE, read_mat
-
-HERE = pathlib.Path(__file__).parent
+from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, BUILTIN_DIR, read_mat


 class Caltech101(Dataset):
@@ -32,7 +30,7 @@ class Caltech101(Dataset):
         return DatasetInfo(
             "caltech101",
             type=DatasetType.IMAGE,
-            categories=HERE / "caltech101.categories",
+            categories=BUILTIN_DIR / "caltech101.categories",
             homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101",
         )
@@ -135,12 +133,11 @@ class Caltech101(Dataset):
         )
         return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

-    def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
+    def _generate_categories(self, root: pathlib.Path) -> List[str]:
         dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
         dp = TarArchiveReader(dp)
         dp: IterDataPipe = Filter(dp, self._is_not_background_image)
-        dir_names = {pathlib.Path(path).parent.name for path, _ in dp}
-        create_categories_file(HERE, self.name, sorted(dir_names))
+        return sorted({pathlib.Path(path).parent.name for path, _ in dp})


 class Caltech256(Dataset):
@@ -149,7 +146,7 @@ class Caltech256(Dataset):
         return DatasetInfo(
             "caltech256",
             type=DatasetType.IMAGE,
-            categories=HERE / "caltech256.categories",
+            categories=BUILTIN_DIR / "caltech256.categories",
             homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256",
         )
@@ -192,17 +189,8 @@ class Caltech256(Dataset):
         dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
         return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

-    def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
+    def _generate_categories(self, root: pathlib.Path) -> List[str]:
         dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
         dp = TarArchiveReader(dp)
         dir_names = {pathlib.Path(path).parent.name for path, _ in dp}
-        categories = [name.split(".")[1] for name in sorted(dir_names)]
-        create_categories_file(HERE, self.name, categories)
-
-
-if __name__ == "__main__":
-    from torchvision.prototype.datasets import home
-
-    root = home()
-    Caltech101().generate_categories_file(root)
-    Caltech256().generate_categories_file(root)
+        return [name.split(".")[1] for name in sorted(dir_names)]
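As a point of reference, the category list that Caltech101._generate_categories returns can be reproduced with plain tarfile, assuming the standard 101_ObjectCategories archive layout with a BACKGROUND_Google clutter folder; the function name below is illustrative, a sketch rather than part of the change:

import pathlib
import tarfile


def caltech101_categories(archive: str) -> list:
    # collect the parent directory name of every file in the archive
    with tarfile.open(archive) as tar:
        dir_names = {pathlib.Path(member.name).parent.name for member in tar if member.isfile()}
    # drop the clutter class, mirroring the _is_not_background_image filter above
    dir_names.discard("BACKGROUND_Google")
    return sorted(dir_names)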
torchvision/prototype/datasets/_builtin/cifar.py
@@ -24,16 +24,14 @@ from torchvision.prototype.datasets.utils import (
     DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import (
-    create_categories_file,
     INFINITE_BUFFER_SIZE,
+    BUILTIN_DIR,
     image_buffer_from_array,
     path_comparator,
 )

 __all__ = ["Cifar10", "Cifar100"]

-HERE = pathlib.Path(__file__).parent
-

 class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]):
     def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None:
@@ -95,13 +93,12 @@ class _CifarBase(Dataset):
         dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
         return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))

-    def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
+    def _generate_categories(self, root: pathlib.Path) -> List[str]:
         dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
         dp = TarArchiveReader(dp)
         dp: IterDataPipe = Filter(dp, path_comparator("name", self._META_FILE_NAME))
         dp: IterDataPipe = Mapper(dp, self._unpickle)
-        categories = next(iter(dp))[self._CATEGORIES_KEY]
-        create_categories_file(HERE, self.name, categories)
+        return next(iter(dp))[self._CATEGORIES_KEY]


 class Cifar10(_CifarBase):
@@ -118,7 +115,7 @@ class Cifar10(_CifarBase):
         return DatasetInfo(
             "cifar10",
             type=DatasetType.RAW,
-            categories=HERE / "cifar10.categories",
+            categories=BUILTIN_DIR / "cifar10.categories",
             homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
         )
@@ -145,7 +142,7 @@ class Cifar100(_CifarBase):
         return DatasetInfo(
             "cifar100",
             type=DatasetType.RAW,
-            categories=HERE / "cifar100.categories",
+            categories=BUILTIN_DIR / "cifar100.categories",
             homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
             valid_options=dict(
                 split=("train", "test"),
@@ -159,11 +156,3 @@ class Cifar100(_CifarBase):
             sha256="85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7",
         )
     ]
-
-
-if __name__ == "__main__":
-    from torchvision.prototype.datasets import home
-
-    root = home()
-    Cifar10().generate_categories_file(root)
-    Cifar100().generate_categories_file(root)
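For CIFAR-10, the _META_FILE_NAME and _CATEGORIES_KEY attributes that _generate_categories relies on (defined outside this hunk) correspond to the pickled batches.meta member and its label_names entry; CIFAR-100 uses meta and fine_label_names instead. A hedged standalone sketch, assuming the public python-version archive layout and the usual latin1 decoding for these Python 2 pickles:

import pickle
import tarfile


def cifar10_categories(archive: str) -> list:
    with tarfile.open(archive) as tar:
        # batches.meta is a pickled dict; its "label_names" entry holds the ten names
        meta_file = tar.extractfile("cifar-10-batches-py/batches.meta")
        meta = pickle.load(meta_file, encoding="latin1")
    return meta["label_names"]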
torchvision/prototype/datasets/_builtin/sbd.py
 import io
 import pathlib
 import re
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple
 import numpy as np
 import torch
@@ -24,16 +24,14 @@ from torchvision.prototype.datasets.utils import (
     DatasetType,
 )
 from torchvision.prototype.datasets.utils._internal import (
-    create_categories_file,
     INFINITE_BUFFER_SIZE,
+    BUILTIN_DIR,
     read_mat,
     getitem,
     path_accessor,
     path_comparator,
 )

-HERE = pathlib.Path(__file__).parent
-

 class SBD(Dataset):
     @property
@@ -41,7 +39,7 @@ class SBD(Dataset):
         return DatasetInfo(
             "sbd",
             type=DatasetType.IMAGE,
-            categories=HERE / "caltech256.categories",
+            categories=BUILTIN_DIR / "caltech256.categories",
             homepage="http://home.bharathh.info/pubs/codes/SBD/download.html",
             valid_options=dict(
                 split=("train", "val", "train_noval"),
@@ -158,7 +156,7 @@ class SBD(Dataset):
         )
         return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(config=config, decoder=decoder))

-    def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
+    def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
         dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
         dp = TarArchiveReader(dp)
         dp: IterDataPipe = Filter(dp, path_comparator("name", "category_names.m"))
@@ -172,15 +170,4 @@ class SBD(Dataset):
             # the first and last line contain no information
             for line in lines[1:-1]
         ]
-        categories = tuple(
-            zip(*sorted(categories_and_labels, key=lambda category_and_label: int(category_and_label[1])))
-        )[0]
-        create_categories_file(HERE, self.name, categories)
-
-
-if __name__ == "__main__":
-    from torchvision.prototype.datasets import home
-
-    root = home()
-    SBD().generate_categories_file(root)
+        return tuple(zip(*sorted(categories_and_labels, key=lambda category_and_label: int(category_and_label[1]))))[0]
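The new one-line return packs three steps together: sort the (category, label) pairs numerically by label, transpose with zip(*...), and keep the first row. A tiny demonstration with illustrative data:

pairs = [("bird", "3"), ("aeroplane", "1"), ("bicycle", "2")]

# sort by the numeric label, transpose, keep the categories row
categories = tuple(zip(*sorted(pairs, key=lambda pair: int(pair[1]))))[0]
assert categories == ("aeroplane", "bicycle", "bird")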
torchvision/prototype/datasets/generate_category_files.py (new file)
+import argparse
+import sys
+import unittest.mock
+import warnings
+
+with warnings.catch_warnings():
+    warnings.filterwarnings("ignore", message=r"The categories file .+? does not exist.", category=UserWarning)
+
+    from torchvision.prototype import datasets
+
+from torchvision.prototype.datasets._api import find
+from torchvision.prototype.datasets.utils._internal import BUILTIN_DIR
+
+
+def main(*names, force=False):
+    root = datasets.home()
+
+    for name in names:
+        file = BUILTIN_DIR / f"{name}.categories"
+        if file.exists() and not force:
+            continue
+
+        dataset = find(name)
+        try:
+            with unittest.mock.patch(
+                "torchvision.prototype.datasets.utils._dataset.DatasetInfo._read_categories_file", return_value=[]
+            ):
+                categories = dataset._generate_categories(root)
+        except NotImplementedError:
+            continue
+
+        with open(file, "w") as fh:
+            fh.write("\n".join(categories) + "\n")
+
+
+def parse_args(argv=None):
+    parser = argparse.ArgumentParser(prog="torchvision.prototype.datasets.generate_category_files.py")
+
+    parser.add_argument(
+        "names",
+        nargs="*",
+        type=str,
+        help="Names of datasets to generate category files for. If omitted, all datasets will be used.",
+    )
+    parser.add_argument(
+        "-f",
+        "--force",
+        action="store_true",
+        help="Force regeneration of category files.",
+    )
+
+    args = parser.parse_args(argv or sys.argv[1:])
+
+    if not args.names:
+        args.names = datasets.list()
+
+    return args
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    try:
+        main(*args.names, force=args.force)
+    except Exception as error:
+        msg = str(error)
+        print(msg or f"Unspecified {type(error)} was raised during execution.", file=sys.stderr)
+        sys.exit(1)
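Because main() stays importable, the category files can also be regenerated programmatically; a hypothetical invocation (the dataset names are illustrative, and the module path follows from the file's location under torchvision/prototype/datasets/):

from torchvision.prototype.datasets.generate_category_files import main

# rewrite caltech101.categories and caltech256.categories even if they already exist
main("caltech101", "caltech256", force=True)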
torchvision/prototype/datasets/utils/_dataset.py
@@ -4,6 +4,7 @@ import io
 import os
 import pathlib
 import textwrap
+import warnings
 from collections import Mapping
 from typing import (
     Any,
@@ -117,8 +118,7 @@ class DatasetInfo:
         elif isinstance(categories, int):
             categories = [str(label) for label in range(categories)]
         elif isinstance(categories, (str, pathlib.Path)):
-            with open(pathlib.Path(categories).expanduser().resolve(), "r") as fh:
-                categories = [line.strip() for line in fh]
+            categories = self._read_categories_file(pathlib.Path(categories).expanduser().resolve())
         self.categories = tuple(categories)

         self.citation = citation
@@ -137,6 +137,17 @@ class DatasetInfo:
         )
         self._valid_options: Dict[str, Sequence] = valid_options

+    @staticmethod
+    def _read_categories_file(path: pathlib.Path) -> List[str]:
+        if not path.exists() or not path.is_file():
+            warnings.warn(
+                f"The categories file {path} does not exist. Continuing without loaded categories.", UserWarning
+            )
+            return []
+
+        with open(path, "r") as file:
+            return [line.strip() for line in file]
+
     @property
     def default_config(self) -> DatasetConfig:
         return DatasetConfig({name: valid_args[0] for name, valid_args in self._valid_options.items()})
@@ -219,3 +230,6 @@ class Dataset(abc.ABC):
         resource_dps = [resource.to_datapipe(root) for resource in self.resources(config)]

         return self._make_datapipe(resource_dps, config=config, decoder=decoder)
+
+    def _generate_categories(self, root: pathlib.Path) -> Sequence[str]:
+        raise NotImplementedError
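The warn-and-return-empty behavior lets datasets be instantiated before their category files exist, which is exactly the state generate_category_files.py runs in. The write format there and the read format here are inverses; a minimal self-contained round-trip check:

import pathlib
import tempfile

categories = ["airplane", "automobile", "bird"]
with tempfile.TemporaryDirectory() as tmp:
    file = pathlib.Path(tmp) / "demo.categories"
    # write format used by generate_category_files.py
    file.write_text("\n".join(categories) + "\n")
    # read format used by DatasetInfo._read_categories_file
    with open(file) as fh:
        assert [line.strip() for line in fh] == categories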
torchvision/prototype/datasets/utils/_internal.py
@@ -15,9 +15,9 @@ from torch.utils.data import IterDataPipe
 __all__ = [
     "INFINITE_BUFFER_SIZE",
+    "BUILTIN_DIR",
     "sequence_to_str",
     "add_suggestion",
-    "create_categories_file",
     "read_mat",
     "image_buffer_from_array",
     "SequenceIterator",
@@ -35,6 +35,8 @@ D = TypeVar("D")
 # pseudo-infinite until a true infinite buffer is supported by all datapipes
 INFINITE_BUFFER_SIZE = 1_000_000_000

+BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin"
+

 def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
     if len(seq) == 1:
@@ -60,11 +62,6 @@ def add_suggestion(
     return f"{msg.strip()} {hint}"


-def create_categories_file(root: Union[str, pathlib.Path], name: str, categories: Sequence[str]) -> None:
-    with open(pathlib.Path(root) / f"{name}.categories", "w") as fh:
-        fh.write("\n".join(categories) + "\n")
-
-
 def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
     try:
         import scipy.io as sio
...
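BUILTIN_DIR steps up two directories because _internal.py sits in torchvision/prototype/datasets/utils/, making _builtin, where the .categories files live next to caltech.py, cifar.py, and sbd.py, a sibling of utils. A quick sanity check with relative paths:

import pathlib

internal = pathlib.Path("torchvision/prototype/datasets/utils/_internal.py")
assert internal.parent.parent / "_builtin" == pathlib.Path("torchvision/prototype/datasets/_builtin")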