"official/r1/utils/logs/hooks_helper.py" did not exist on "1e2ceffd5a029f192b274e6b48430f948c5f5f55"
stanford_cars.py 4.34 KB
Newer Older
1
import pathlib
2
from typing import Any, Callable, Optional, Tuple
3
4
5

from PIL import Image

6
from .utils import verify_str_arg
7
8
9
10
from .vision import VisionDataset


class StanfordCars(VisionDataset):
    """Stanford Cars Dataset

    The Cars dataset contains 16,185 images of 196 classes of cars. The data is
    split into 8,144 training images and 8,041 testing images, where each class
    has been split roughly in a 50-50 split.

    The original URL is https://ai.stanford.edu/~jkrause/cars/car_dataset.html, but it is broken.

    .. note::

        This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.

    Expected layout under ``root``::

        stanford_cars/
            devkit/
                cars_meta.mat
                cars_train_annos.mat
            cars_train/
            cars_test/
            cars_test_annos_withlabels.mat

    Args:
        root (string): Root directory of dataset. A leading ``~`` is expanded to the user's
            home directory.
        split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): This parameter exists for backward compatibility but it does not
            download the dataset, since the original URL is not available anymore. The dataset
            seems to be available on Kaggle so you can try to manually download it using
            `these instructions <https://github.com/pytorch/vision/issues/7545#issuecomment-1631441616>`_.

    Raises:
        RuntimeError: If scipy is not installed, or if the expected files are not found
            under ``<root>/stanford_cars``.
        ValueError: If ``download=True`` is passed (automatic download is no longer possible).
    """

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:

        # scipy is an optional dependency, only needed to parse the .mat annotation files.
        try:
            import scipy.io as sio
        except ImportError as e:
            raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy") from e

        super().__init__(root, transform=transform, target_transform=target_transform)

        self._split = verify_str_arg(split, "split", ("train", "test"))
        # Expand "~" so that e.g. root="~/data" resolves to the home directory rather than
        # to a literal "~" folder relative to the current working directory.
        self._base_folder = pathlib.Path(root).expanduser() / "stanford_cars"
        devkit = self._base_folder / "devkit"

        if self._split == "train":
            self._annotations_mat_path = devkit / "cars_train_annos.mat"
            self._images_base_path = self._base_folder / "cars_train"
        else:
            # The labelled test annotations live directly in the base folder, not in devkit/.
            self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
            self._images_base_path = self._base_folder / "cars_test"

        if download:
            raise ValueError(
                "The original URL is broken so the StanfordCars dataset is not available for automatic "
                "download anymore. You can try to download it manually following "
                "https://github.com/pytorch/vision/issues/7545#issuecomment-1631441616, "
                "and set download=False to avoid this error."
            )

        if not self._check_exists():
            raise RuntimeError(
                "Dataset not found. Try to manually download following the instructions in "
                "https://github.com/pytorch/vision/issues/7545#issuecomment-1631441616."
            )

        # Each sample is (absolute image path as str, 0-based class index).
        self._samples = [
            (
                str(self._images_base_path / annotation["fname"]),
                annotation["class"] - 1,  # Original target mapping starts from 1, hence -1
            )
            for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
        ]

        self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self._samples)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return ``(image, target)`` for the sample at ``idx``.

        ``image`` is the RGB PIL image (passed through ``transform`` if one is set) and
        ``target`` is the 0-based class index (passed through ``target_transform`` if set).
        """
        image_path, target = self._samples[idx]
        # Open via an explicit file object so the handle is closed deterministically;
        # convert() reads the pixel data before the ``with`` block closes the file.
        with open(image_path, "rb") as f:
            pil_image = Image.open(f).convert("RGB")

        if self.transform is not None:
            pil_image = self.transform(pil_image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return pil_image, target

    def _check_exists(self) -> bool:
        """Return True if the devkit folder, the split's annotation file and image folder exist."""
        if not (self._base_folder / "devkit").is_dir():
            return False

        return self._annotations_mat_path.exists() and self._images_base_path.is_dir()