Unverified Commit 8ba482a6 authored by Abhijit Deo's avatar Abhijit Deo Committed by GitHub
Browse files

Add support for Stanford cars dataset (#5166)



* [WIP]
*added stanford_cars

* [WIP]
added stanfordCars to docs

* [WIP]
minor edits

* [WIP]
minor edits

* edited StanfordCars class

* Adding Testcase for stanford cars

* Added Testcase for stanford cars

* Added Testcase for stanford cars

* minor edit

* made changes as per the suggestions

* fixed typo in naming stanford_cars.py

* cars_meta.mat file will be created in test

* Some cleanups

* Sigh

* don't convert to strings
Co-authored-by: default avatarNicolas Hug <nicolashug@fb.com>
parent 57a77c45
......@@ -75,6 +75,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
SBU
SEMEION
Sintel
StanfordCars
STL10
SUN397
SVHN
......
......@@ -2535,6 +2535,50 @@ class OxfordIIITPetTestCase(datasets_utils.ImageDatasetTestCase):
return (image_id, class_id, species, breed_id)
class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.StanfordCars
REQUIRED_PACKAGES = ("scipy",)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
def inject_fake_data(self, tmpdir, config):
import scipy.io as io
from numpy.core.records import fromarrays
num_examples = {"train": 5, "test": 7}[config["split"]]
num_classes = 3
base_folder = pathlib.Path(tmpdir) / "stanford_cars"
devkit = base_folder / "devkit"
devkit.mkdir(parents=True)
if config["split"] == "train":
images_folder_name = "cars_train"
annotations_mat_path = devkit / "cars_train_annos.mat"
else:
images_folder_name = "cars_test"
annotations_mat_path = base_folder / "cars_test_annos_withlabels.mat"
datasets_utils.create_image_folder(
root=base_folder,
name=images_folder_name,
file_name_fn=lambda image_index: f"{image_index:5d}.jpg",
num_examples=num_examples,
)
classes = np.random.randint(1, num_classes + 1, num_examples, dtype=np.uint8)
fnames = [f"{i:5d}.jpg" for i in range(num_examples)]
rec_array = fromarrays(
[classes, fnames],
names=["class", "fname"],
)
io.savemat(annotations_mat_path, {"annotations": rec_array})
random_class_names = ["random_name"] * num_classes
io.savemat(devkit / "cars_meta.mat", {"class_names": random_class_names})
return num_examples
class Country211TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Country211
......
......@@ -32,6 +32,7 @@ from .places365 import Places365
from .sbd import SBDataset
from .sbu import SBU
from .semeion import SEMEION
from .stanford_cars import StanfordCars
from .stl10 import STL10
from .sun397 import SUN397
from .svhn import SVHN
......@@ -56,6 +57,7 @@ __all__ = (
"QMNIST",
"MNIST",
"KMNIST",
"StanfordCars",
"STL10",
"SUN397",
"SVHN",
......
import pathlib
from typing import Callable, Optional, Any, Tuple
from PIL import Image
from .utils import download_and_extract_archive, download_url, verify_str_arg
from .vision import VisionDataset
class StanfordCars(VisionDataset):
"""`Stanford Cars <https://ai.stanford.edu/~jkrause/cars/car_dataset.html>`_ Dataset
The Cars dataset contains 16,185 images of 196 classes of cars. The data is
split into 8,144 training images and 8,041 testing images, where each class
has been split roughly in a 50-50 split
.. note::
This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.
Args:
root (string): Root directory of dataset
split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again."""
def __init__(
self,
root: str,
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
try:
import scipy.io as sio
except ImportError:
raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
super().__init__(root, transform=transform, target_transform=target_transform)
self._split = verify_str_arg(split, "split", ("train", "test"))
self._base_folder = pathlib.Path(root) / "stanford_cars"
devkit = self._base_folder / "devkit"
if self._split == "train":
self._annotations_mat_path = devkit / "cars_train_annos.mat"
self._images_base_path = self._base_folder / "cars_train"
else:
self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
self._images_base_path = self._base_folder / "cars_test"
if download:
self.download()
if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")
self._samples = [
(
str(self._images_base_path / annotation["fname"]),
annotation["class"] - 1, # Original target mapping starts from 1, hence -1
)
for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
]
self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
def __len__(self) -> int:
return len(self._samples)
def __getitem__(self, idx: int) -> Tuple[Any, Any]:
"""Returns pil_image and class_id for given index"""
image_path, target = self._samples[idx]
pil_image = Image.open(image_path).convert("RGB")
if self.transform is not None:
pil_image = self.transform(pil_image)
if self.target_transform is not None:
target = self.target_transform(target)
return pil_image, target
def download(self) -> None:
if self._check_exists():
return
download_and_extract_archive(
url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
download_root=str(self._base_folder),
md5="c3b158d763b6e2245038c8ad08e45376",
)
if self._split == "train":
download_and_extract_archive(
url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
download_root=str(self._base_folder),
md5="065e5b463ae28d29e77c1b4b166cfe61",
)
else:
download_and_extract_archive(
url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
download_root=str(self._base_folder),
md5="4ce7ebf6a94d07f1952d94dd34c4d501",
)
download_url(
url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
root=str(self._base_folder),
md5="b0a2b23655a3edd16d84508592a98d10",
)
def _check_exists(self) -> bool:
if not (self._base_folder / "devkit").is_dir():
return False
return self._annotations_mat_path.exists() and self._images_base_path.is_dir()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment