Unverified Commit adf8466e authored by Yiwen Song's avatar Yiwen Song Committed by GitHub
Browse files

Adding fvgc_aircraft dataset (#5178)



* add fvgc_aircraft dataset

* add docstring & remove useless import

* resolve lint issue

* address comments

* adding more annotation level

* nit

* address comments

* Apply suggestions from code review

* unify format

* remove useless line
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 1feb6376
...@@ -50,6 +50,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas ...@@ -50,6 +50,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
FlyingChairs FlyingChairs
FlyingThings3D FlyingThings3D
Food101 Food101
FGVCAircraft
GTSRB GTSRB
HD1K HD1K
HMDB51 HMDB51
......
...@@ -2206,6 +2206,57 @@ class Food101TestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2206,6 +2206,57 @@ class Food101TestCase(datasets_utils.ImageDatasetTestCase):
return len(sampled_classes * n_samples_per_class) return len(sampled_classes * n_samples_per_class)
class FGVCAircraftTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.FGVCAircraft
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
split=("train", "val", "trainval", "test"), annotation_level=("variant", "family", "manufacturer")
)
def inject_fake_data(self, tmpdir: str, config):
split = config["split"]
annotation_level = config["annotation_level"]
annotation_level_to_file = {
"variant": "variants.txt",
"family": "families.txt",
"manufacturer": "manufacturers.txt",
}
root_folder = pathlib.Path(tmpdir) / "fgvc-aircraft-2013b"
data_folder = root_folder / "data"
classes = ["707-320", "Hawk T1", "Tornado"]
num_images_per_class = 5
datasets_utils.create_image_folder(
data_folder,
"images",
file_name_fn=lambda idx: f"{idx}.jpg",
num_examples=num_images_per_class * len(classes),
)
annotation_file = data_folder / annotation_level_to_file[annotation_level]
with open(annotation_file, "w") as file:
file.write("\n".join(classes))
num_samples_per_class = 4 if split == "trainval" else 2
images_classes = []
for i in range(len(classes)):
images_classes.extend(
[
f"{idx} {classes[i]}"
for idx in random.sample(
range(i * num_images_per_class, (i + 1) * num_images_per_class), num_samples_per_class
)
]
)
images_annotation_file = data_folder / f"images_{annotation_level}_{split}.txt"
with open(images_annotation_file, "w") as file:
file.write("\n".join(images_classes))
return len(classes * num_samples_per_class)
class SUN397TestCase(datasets_utils.ImageDatasetTestCase): class SUN397TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SUN397 DATASET_CLASS = datasets.SUN397
......
...@@ -9,6 +9,7 @@ from .country211 import Country211 ...@@ -9,6 +9,7 @@ from .country211 import Country211
from .dtd import DTD from .dtd import DTD
from .fakedata import FakeData from .fakedata import FakeData
from .fer2013 import FER2013 from .fer2013 import FER2013
from .fgvc_aircraft import FGVCAircraft
from .flickr import Flickr8k, Flickr30k from .flickr import Flickr8k, Flickr30k
from .flowers102 import Flowers102 from .flowers102 import Flowers102
from .folder import ImageFolder, DatasetFolder from .folder import ImageFolder, DatasetFolder
...@@ -95,4 +96,5 @@ __all__ = ( ...@@ -95,4 +96,5 @@ __all__ = (
"CLEVRClassification", "CLEVRClassification",
"OxfordIIITPet", "OxfordIIITPet",
"Country211", "Country211",
"FGVCAircraft",
) )
from __future__ import annotations
import os
from typing import Any, Callable, Optional, Tuple
import PIL.Image
from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset
class FGVCAircraft(VisionDataset):
"""`FGVC Aircraft <https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/>`_ Dataset.
The dataset contains 10,200 images of aircraft, with 100 images for each of 102
different aircraft model variants, most of which are airplanes.
Aircraft models are organized in a three-levels hierarchy. The three levels, from
finer to coarser, are:
- ``variant``, e.g. Boeing 737-700. A variant collapses all the models that are visually
indistinguishable into one class. The dataset comprises 102 different variants.
- ``family``, e.g. Boeing 737. The dataset comprises 70 different families.
- ``manufacturer``, e.g. Boeing. The dataset comprises 41 different manufacturers.
Args:
root (string): Root directory of the FGVC Aircraft dataset.
split (string, optional): The dataset split, supports ``train``, ``val``,
``trainval`` and ``test``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
annotation_level (str, optional): The annotation level, supports ``variant``,
``family`` and ``manufacturer``.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
_URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz"
def __init__(
self,
root: str,
split: str = "trainval",
download: bool = False,
annotation_level: str = "variant",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
self._annotation_level = verify_str_arg(
annotation_level, "annotation_level", ("variant", "family", "manufacturer")
)
self._data_path = os.path.join(self.root, "fgvc-aircraft-2013b")
if download:
self._download()
if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")
annotation_file = os.path.join(
self._data_path,
"data",
{
"variant": "variants.txt",
"family": "families.txt",
"manufacturer": "manufacturers.txt",
}[self._annotation_level],
)
with open(annotation_file, "r") as f:
self.classes = [line.strip() for line in f]
self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
image_data_folder = os.path.join(self._data_path, "data", "images")
labels_file = os.path.join(self._data_path, "data", f"images_{self._annotation_level}_{self._split}.txt")
self._image_files = []
self._labels = []
with open(labels_file, "r") as f:
for line in f:
image_name, label_name = line.strip().split(" ", 1)
self._image_files.append(os.path.join(image_data_folder, f"{image_name}.jpg"))
self._labels.append(self.class_to_idx[label_name])
def __len__(self) -> int:
return len(self._image_files)
def __getitem__(self, idx) -> Tuple[Any, Any]:
image_file, label = self._image_files[idx], self._labels[idx]
image = PIL.Image.open(image_file).convert("RGB")
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
def _download(self) -> None:
"""
Download the FGVC Aircraft dataset archive and extract it under root.
"""
if self._check_exists():
return
download_and_extract_archive(self._URL, self.root)
def _check_exists(self) -> bool:
return os.path.exists(self._data_path) and os.path.isdir(self._data_path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment