"src/libtorio/ffmpeg/pybind/pybind.cpp" did not exist on "4d2fa1908e531c6c815c026533b0e51a10ef9aef"
Unverified Commit 8c546f60 authored by Saswat Das's avatar Saswat Das Committed by GitHub
Browse files

Add SUN397 Dataset (#5132)



* dataset class added

* fix code format

* fixed requested changes

* fixed issues in sun397

* Update torchvision/datasets/sun397.py
Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 5d5ecb45
......@@ -71,6 +71,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
SEMEION
Sintel
STL10
SUN397
SVHN
UCF101
USPS
......
......@@ -2206,6 +2206,52 @@ class Food101TestCase(datasets_utils.ImageDatasetTestCase):
return len(sampled_classes * n_samples_per_class)
class SUN397TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SUN397
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
split=("train", "test"),
partition=(1, 10, None),
)
def inject_fake_data(self, tmpdir: str, config):
data_dir = pathlib.Path(tmpdir) / "SUN397"
data_dir.mkdir()
num_images_per_class = 5
sampled_classes = ("abbey", "airplane_cabin", "airport_terminal")
im_paths = []
for cls in sampled_classes:
image_folder = data_dir / cls[0]
im_paths.extend(
datasets_utils.create_image_folder(
image_folder,
image_folder / cls,
file_name_fn=lambda idx: f"sun_{idx}.jpg",
num_examples=num_images_per_class,
)
)
with open(data_dir / "ClassName.txt", "w") as file:
file.writelines("\n".join(f"/{cls[0]}/{cls}" for cls in sampled_classes))
if config["partition"] is not None:
num_samples = max(len(im_paths) // (2 if config["split"] == "train" else 3), 1)
with open(data_dir / f"{config['split'].title()}ing_{config['partition']:02d}.txt", "w") as file:
file.writelines(
"\n".join(
f"/{f_path.relative_to(data_dir).as_posix()}"
for f_path in random.choices(im_paths, k=num_samples)
)
)
else:
num_samples = len(im_paths)
return num_samples
class DTDTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.DTD
FEATURE_TYPES = (PIL.Image.Image, int)
......
......@@ -28,6 +28,7 @@ from .sbd import SBDataset
from .sbu import SBU
from .semeion import SEMEION
from .stl10 import STL10
from .sun397 import SUN397
from .svhn import SVHN
from .ucf101 import UCF101
from .usps import USPS
......@@ -51,6 +52,7 @@ __all__ = (
"MNIST",
"KMNIST",
"STL10",
"SUN397",
"SVHN",
"PhotoTour",
"SEMEION",
......
from pathlib import Path
from typing import Any, Tuple, Callable, Optional
import PIL.Image
from .utils import verify_str_arg, download_and_extract_archive
from .vision import VisionDataset
class SUN397(VisionDataset):
"""`The SUN397 Data Set <https://vision.princeton.edu/projects/2010/SUN/>`_.
The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of
397 categories with 108'754 images. The dataset also provides 10 partitions for training
and testing, with each partition consisting of 50 images per class.
Args:
root (string): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
partition (int, optional): A valid partition can be an integer from 1 to 10 or None,
for the entire dataset.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
"""
_DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz"
_DATASET_MD5 = "8ca2778205c41d23104230ba66911c7a"
_PARTITIONS_URL = "https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip"
_PARTITIONS_MD5 = "29a205c0a0129d21f36cbecfefe81881"
def __init__(
self,
root: str,
split: str = "train",
partition: Optional[int] = 1,
download: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
self.split = verify_str_arg(split, "split", ("train", "test"))
self.partition = partition
self._data_dir = Path(self.root) / "SUN397"
if self.partition is not None:
if self.partition < 0 or self.partition > 10:
raise RuntimeError(f"The partition parameter should be an int in [1, 10] or None, got {partition}.")
if download:
self._download()
if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")
with open(self._data_dir / "ClassName.txt") as f:
self.classes = [c[3:].strip() for c in f]
self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
if self.partition is not None:
with open(self._data_dir / f"{self.split.title()}ing_{self.partition:02d}.txt", "r") as f:
self._image_files = [self._data_dir.joinpath(*line.strip()[1:].split("/")) for line in f]
else:
self._image_files = list(self._data_dir.rglob("sun_*.jpg"))
self._labels = [
self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])] for path in self._image_files
]
def __len__(self) -> int:
return len(self._image_files)
def __getitem__(self, idx) -> Tuple[Any, Any]:
image_file, label = self._image_files[idx], self._labels[idx]
image = PIL.Image.open(image_file).convert("RGB")
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
def _check_exists(self) -> bool:
return self._data_dir.exists() and self._data_dir.is_dir()
def extra_repr(self) -> str:
return "Split: {split}".format(**self.__dict__)
def _download(self) -> None:
if self._check_exists():
return
download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._DATASET_MD5)
download_and_extract_archive(self._PARTITIONS_URL, download_root=str(self._data_dir), md5=self._PARTITIONS_MD5)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment