"examples/community/imagic_stable_diffusion.py" did not exist on "726aba089d12503249d824bbaf4070f47d0fe44d"
Unverified commit e4a4a29a authored by Philip Meier, committed by GitHub

streamline category file generation for prototype datasets (#4642)



* streamline category file generation for prototype datasets

* cleanup
Co-authored-by: Prabhat Roy <prabhatroy@fb.com>
parent 9bee9cc4
torchvision/prototype/datasets/_builtin/caltech.py
import io
import pathlib
import re
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
@@ -21,9 +21,7 @@ from torchvision.prototype.datasets.utils import (
OnlineResource,
DatasetType,
)
-from torchvision.prototype.datasets.utils._internal import create_categories_file, INFINITE_BUFFER_SIZE, read_mat
-HERE = pathlib.Path(__file__).parent
+from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, BUILTIN_DIR, read_mat
class Caltech101(Dataset):
@@ -32,7 +30,7 @@ class Caltech101(Dataset):
return DatasetInfo(
"caltech101",
type=DatasetType.IMAGE,
categories=HERE / "caltech101.categories",
categories=BUILTIN_DIR / "caltech101.categories",
homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101",
)
@@ -135,12 +133,11 @@ class Caltech101(Dataset):
)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
-def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
+def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
dp = TarArchiveReader(dp)
dp: IterDataPipe = Filter(dp, self._is_not_background_image)
-dir_names = {pathlib.Path(path).parent.name for path, _ in dp}
-create_categories_file(HERE, self.name, sorted(dir_names))
+return sorted({pathlib.Path(path).parent.name for path, _ in dp})
class Caltech256(Dataset):
@@ -149,7 +146,7 @@ class Caltech256(Dataset):
return DatasetInfo(
"caltech256",
type=DatasetType.IMAGE,
categories=HERE / "caltech256.categories",
categories=BUILTIN_DIR / "caltech256.categories",
homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256",
)
@@ -192,17 +189,8 @@ class Caltech256(Dataset):
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
-def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
+def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
dp = TarArchiveReader(dp)
dir_names = {pathlib.Path(path).parent.name for path, _ in dp}
-categories = [name.split(".")[1] for name in sorted(dir_names)]
-create_categories_file(HERE, self.name, categories)
-if __name__ == "__main__":
-    from torchvision.prototype.datasets import home
-    root = home()
-    Caltech101().generate_categories_file(root)
-    Caltech256().generate_categories_file(root)
+return [name.split(".")[1] for name in sorted(dir_names)]
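For context on the return value above: Caltech256 class directories carry a numeric prefix separated by a dot, which split(".")[1] strips. A minimal sketch with illustrative directory names (not taken from the diff):

# Illustrative only: Caltech256 folders follow a "NNN.name" convention.
dir_names = {"002.american-flag", "001.ak47", "257.clutter"}
categories = [name.split(".")[1] for name in sorted(dir_names)]
print(categories)  # ['ak47', 'american-flag', 'clutter']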
torchvision/prototype/datasets/_builtin/cifar.py
@@ -24,16 +24,14 @@ from torchvision.prototype.datasets.utils import (
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
-create_categories_file,
INFINITE_BUFFER_SIZE,
+BUILTIN_DIR,
image_buffer_from_array,
path_comparator,
)
__all__ = ["Cifar10", "Cifar100"]
-HERE = pathlib.Path(__file__).parent
class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]):
def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None:
@@ -95,13 +93,12 @@ class _CifarBase(Dataset):
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))
-def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
+def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
dp = TarArchiveReader(dp)
dp: IterDataPipe = Filter(dp, path_comparator("name", self._META_FILE_NAME))
dp: IterDataPipe = Mapper(dp, self._unpickle)
-categories = next(iter(dp))[self._CATEGORIES_KEY]
-create_categories_file(HERE, self.name, categories)
+return next(iter(dp))[self._CATEGORIES_KEY]
class Cifar10(_CifarBase):
@@ -118,7 +115,7 @@ class Cifar10(_CifarBase):
return DatasetInfo(
"cifar10",
type=DatasetType.RAW,
categories=HERE / "cifar10.categories",
categories=BUILTIN_DIR / "cifar10.categories",
homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
)
@@ -145,7 +142,7 @@ class Cifar100(_CifarBase):
return DatasetInfo(
"cifar100",
type=DatasetType.RAW,
categories=HERE / "cifar100.categories",
categories=BUILTIN_DIR / "cifar100.categories",
homepage="https://www.cs.toronto.edu/~kriz/cifar.html",
valid_options=dict(
split=("train", "test"),
@@ -159,11 +156,3 @@ class Cifar100(_CifarBase):
sha256="85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7",
)
]
-if __name__ == "__main__":
-    from torchvision.prototype.datasets import home
-    root = home()
-    Cifar10().generate_categories_file(root)
-    Cifar100().generate_categories_file(root)
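For reference, the returned categories come straight from the pickled meta file inside each CIFAR archive; _CATEGORIES_KEY is "label_names" for CIFAR-10. A rough standalone equivalent, with archive and member names as in the upstream CIFAR-10 release (adjust if your copy differs):

import pickle
import tarfile

# Sketch: read CIFAR-10 category names directly from the archive's meta file.
with tarfile.open("cifar-10-python.tar.gz") as archive:
    meta = archive.extractfile("cifar-10-batches-py/batches.meta")
    categories = pickle.load(meta, encoding="latin1")["label_names"]
print(categories)  # ['airplane', 'automobile', ..., 'truck']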
torchvision/prototype/datasets/_builtin/sbd.py
import io
import pathlib
import re
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
@@ -24,16 +24,14 @@ from torchvision.prototype.datasets.utils import (
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
-create_categories_file,
INFINITE_BUFFER_SIZE,
+BUILTIN_DIR,
read_mat,
getitem,
path_accessor,
path_comparator,
)
-HERE = pathlib.Path(__file__).parent
class SBD(Dataset):
@property
@@ -41,7 +39,7 @@ class SBD(Dataset):
return DatasetInfo(
"sbd",
type=DatasetType.IMAGE,
categories=HERE / "caltech256.categories",
categories=BUILTIN_DIR / "caltech256.categories",
homepage="http://home.bharathh.info/pubs/codes/SBD/download.html",
valid_options=dict(
split=("train", "val", "train_noval"),
@@ -158,7 +156,7 @@ class SBD(Dataset):
)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(config=config, decoder=decoder))
-def generate_categories_file(self, root: Union[str, pathlib.Path]) -> None:
+def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
dp = self.resources(self.default_config)[0].to_datapipe(pathlib.Path(root) / self.name)
dp = TarArchiveReader(dp)
dp: IterDataPipe = Filter(dp, path_comparator("name", "category_names.m"))
@@ -172,15 +170,4 @@ class SBD(Dataset):
# the first and last line contain no information
for line in lines[1:-1]
]
-categories = tuple(
-    zip(*sorted(categories_and_labels, key=lambda category_and_label: int(category_and_label[1])))
-)[0]
-create_categories_file(HERE, self.name, categories)
-if __name__ == "__main__":
-    from torchvision.prototype.datasets import home
-    root = home()
-    SBD().generate_categories_file(root)
+return tuple(zip(*sorted(categories_and_labels, key=lambda category_and_label: int(category_and_label[1]))))[0]
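The single-expression return above combines a sort and a transpose; spelled out with hypothetical (category, label) pairs:

# Sort (category, label) pairs by numeric label, then unzip with zip(*...).
categories_and_labels = [("bicycle", "2"), ("bird", "3"), ("aeroplane", "1")]
pairs = sorted(categories_and_labels, key=lambda pair: int(pair[1]))
names, labels = zip(*pairs)
print(names)  # ('aeroplane', 'bicycle', 'bird')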
torchvision/prototype/datasets/generate_category_files.py (new file)
import argparse
import sys
import unittest.mock
import warnings

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", message=r"The categories file .+? does not exist.", category=UserWarning)

    from torchvision.prototype import datasets
    from torchvision.prototype.datasets._api import find
    from torchvision.prototype.datasets.utils._internal import BUILTIN_DIR


def main(*names, force=False):
    root = datasets.home()

    for name in names:
        file = BUILTIN_DIR / f"{name}.categories"
        if file.exists() and not force:
            continue

        dataset = find(name)
        try:
            with unittest.mock.patch(
                "torchvision.prototype.datasets.utils._dataset.DatasetInfo._read_categories_file", return_value=[]
            ):
                categories = dataset._generate_categories(root)
        except NotImplementedError:
            continue

        with open(file, "w") as fh:
            fh.write("\n".join(categories) + "\n")


def parse_args(argv=None):
    parser = argparse.ArgumentParser(prog="torchvision.prototype.datasets.generate_category_files.py")

    parser.add_argument(
        "names",
        nargs="*",
        type=str,
        help="Names of datasets to generate category files for. If omitted, all datasets will be used.",
    )
    parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        help="Force regeneration of category files.",
    )

    args = parser.parse_args(argv or sys.argv[1:])

    if not args.names:
        args.names = datasets.list()

    return args


if __name__ == "__main__":
    args = parse_args()

    try:
        main(*args.names, force=args.force)
    except Exception as error:
        msg = str(error)
        print(msg or f"Unspecified {type(error)} was raised during execution.", file=sys.stderr)
        sys.exit(1)
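Assuming the raw archives already live under datasets.home(), the script can also be driven from Python. A hypothetical session (dataset names are examples, and the import path assumes the file location named above):

# Hypothetical driver: regenerate category files for two builtin datasets.
from torchvision.prototype.datasets import generate_category_files

generate_category_files.main("caltech101", "caltech256", force=True)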
torchvision/prototype/datasets/utils/_dataset.py
@@ -4,6 +4,7 @@ import io
import os
import pathlib
import textwrap
+import warnings
from collections import Mapping
from typing import (
Any,
@@ -117,8 +118,7 @@ class DatasetInfo:
elif isinstance(categories, int):
categories = [str(label) for label in range(categories)]
elif isinstance(categories, (str, pathlib.Path)):
-with open(pathlib.Path(categories).expanduser().resolve(), "r") as fh:
-    categories = [line.strip() for line in fh]
+categories = self._read_categories_file(pathlib.Path(categories).expanduser().resolve())
self.categories = tuple(categories)
self.citation = citation
@@ -137,6 +137,17 @@
)
self._valid_options: Dict[str, Sequence] = valid_options
+    @staticmethod
+    def _read_categories_file(path: pathlib.Path) -> List[str]:
+        if not path.exists() or not path.is_file():
+            warnings.warn(
+                f"The categories file {path} does not exist. Continuing without loaded categories.", UserWarning
+            )
+            return []
+
+        with open(path, "r") as file:
+            return [line.strip() for line in file]
@property
def default_config(self) -> DatasetConfig:
return DatasetConfig({name: valid_args[0] for name, valid_args in self._valid_options.items()})
@@ -219,3 +230,6 @@ class Dataset(abc.ABC):
resource_dps = [resource.to_datapipe(root) for resource in self.resources(config)]
return self._make_datapipe(resource_dps, config=config, decoder=decoder)
+    def _generate_categories(self, root: pathlib.Path) -> Sequence[str]:
+        raise NotImplementedError
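Taken together with the constructor hunk above, categories may be given as a sequence, an int, or a path, and a missing file now degrades to an empty list instead of raising. A compressed, self-contained restatement of that branch (the helper name here is mine, not part of the diff):

import pathlib
from typing import Sequence, Union

# Sketch of the categories normalization in DatasetInfo.__init__:
# sequence -> as-is, int -> stringified range, path -> file contents or [].
def normalize_categories(categories: Union[int, Sequence[str], str, pathlib.Path]) -> tuple:
    if isinstance(categories, int):
        categories = [str(label) for label in range(categories)]
    elif isinstance(categories, (str, pathlib.Path)):
        path = pathlib.Path(categories).expanduser().resolve()
        categories = [line.strip() for line in open(path)] if path.is_file() else []
    return tuple(categories)

print(normalize_categories(3))               # ('0', '1', '2')
print(normalize_categories(["cat", "dog"]))  # ('cat', 'dog')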
torchvision/prototype/datasets/utils/_internal.py
@@ -15,9 +15,9 @@ from torch.utils.data import IterDataPipe
__all__ = [
"INFINITE_BUFFER_SIZE",
"BUILTIN_DIR",
"sequence_to_str",
"add_suggestion",
"create_categories_file",
"read_mat",
"image_buffer_from_array",
"SequenceIterator",
@@ -35,6 +35,8 @@ D = TypeVar("D")
# pseudo-infinite until a true infinite buffer is supported by all datapipes
INFINITE_BUFFER_SIZE = 1_000_000_000
+BUILTIN_DIR = pathlib.Path(__file__).parent.parent / "_builtin"
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
if len(seq) == 1:
@@ -60,11 +62,6 @@
return f"{msg.strip()} {hint}"
-def create_categories_file(root: Union[str, pathlib.Path], name: str, categories: Sequence[str]) -> None:
-    with open(pathlib.Path(root) / f"{name}.categories", "w") as fh:
-        fh.write("\n".join(categories) + "\n")
def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any:
try:
import scipy.io as sio