Unverified Commit e047623a authored by Nicolas Hug, committed by GitHub

Cleanups for FLAVA datasets (#5164)

* Change default of download for Food101 and DTD

* Set download default to False and put it at the end

* Keep stuff private

* GTSRB: train -> split. Also use pathlib

* mypy

* Remove split and partition for SUN397

* mypy

* mypy

* move download param for SST2

* Use make_dataset in SST2

* Use a base URL for GTSRB

* Let's make this code more complicated than it needs to be because why not
parent 563d9cad
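Taken together, the changes below make `download` default to False (and move it to the end of the signature) for several datasets, and replace GTSRB's boolean `train` flag with a `split` string. A minimal usage sketch of the resulting API, assuming a torchvision build that includes this commit and using "data" as a placeholder root:

from torchvision import datasets

# `split` replaces the old boolean `train` flag for GTSRB.
gtsrb_test = datasets.GTSRB(root="data", split="test", download=True)

# `download` now defaults to False and sits at the end of the signature, so a
# fresh download has to be requested explicitly by keyword.
food = datasets.Food101(root="data", split="train", download=True)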
......@@ -117,3 +117,7 @@ ignore_missing_imports = True
[mypy-torchdata.*]
ignore_missing_imports = True
[mypy-h5py.*]
ignore_missing_imports = True
......@@ -2281,11 +2281,6 @@ class FGVCAircraftTestCase(datasets_utils.ImageDatasetTestCase):
class SUN397TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SUN397
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
split=("train", "test"),
partition=(1, 10, None),
)
def inject_fake_data(self, tmpdir: str, config):
data_dir = pathlib.Path(tmpdir) / "SUN397"
data_dir.mkdir()
......@@ -2308,17 +2303,6 @@ class SUN397TestCase(datasets_utils.ImageDatasetTestCase):
with open(data_dir / "ClassName.txt", "w") as file:
file.writelines("\n".join(f"/{cls[0]}/{cls}" for cls in sampled_classes))
if config["partition"] is not None:
num_samples = max(len(im_paths) // (2 if config["split"] == "train" else 3), 1)
with open(data_dir / f"{config['split'].title()}ing_{config['partition']:02d}.txt", "w") as file:
file.writelines(
"\n".join(
f"/{f_path.relative_to(data_dir).as_posix()}"
for f_path in random.choices(im_paths, k=num_samples)
)
)
else:
num_samples = len(im_paths)
return num_samples
......@@ -2397,17 +2381,17 @@ class GTSRBTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.GTSRB
FEATURE_TYPES = (PIL.Image.Image, int)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
def inject_fake_data(self, tmpdir: str, config):
root_folder = os.path.join(tmpdir, "GTSRB")
root_folder = os.path.join(tmpdir, "gtsrb")
os.makedirs(root_folder, exist_ok=True)
# Train data
train_folder = os.path.join(root_folder, "Training")
train_folder = os.path.join(root_folder, "GTSRB", "Training")
os.makedirs(train_folder, exist_ok=True)
num_examples = 3
num_examples = 3 if config["split"] == "train" else 4
classes = ("00000", "00042", "00012")
for class_idx in classes:
datasets_utils.create_image_folder(
......@@ -2419,7 +2403,7 @@ class GTSRBTestCase(datasets_utils.ImageDatasetTestCase):
total_number_of_examples = num_examples * len(classes)
# Test data
test_folder = os.path.join(root_folder, "Final_Test", "Images")
test_folder = os.path.join(root_folder, "GTSRB", "Final_Test", "Images")
os.makedirs(test_folder, exist_ok=True)
with open(os.path.join(root_folder, "GT-final_test.csv"), "w") as csv_file:
......
......@@ -34,7 +34,7 @@ class CLEVRClassification(VisionDataset):
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = True,
download: bool = False,
) -> None:
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
super().__init__(root, transform=transform, target_transform=target_transform)
......
......@@ -32,7 +32,7 @@ class Country211(ImageFolder):
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = True,
download: bool = False,
) -> None:
self._split = verify_str_arg(split, "split", ("train", "valid", "test"))
......
......@@ -21,12 +21,12 @@ class DTD(VisionDataset):
The partition only changes which split each image belongs to. Thus, regardless of the selected
partition, combining all splits will result in all images.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. Default is False.
"""
_URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
......@@ -37,9 +37,9 @@ class DTD(VisionDataset):
root: str,
split: str = "train",
partition: int = 1,
download: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
if not isinstance(partition, int) and not (1 <= partition <= 10):
......
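Since `download` now follows the transforms in DTD's signature, it is passed by keyword rather than positionally. A hedged sketch with a placeholder root:

from torchvision import datasets

# `download` comes after transform/target_transform and defaults to False.
dtd = datasets.DTD(root="data", split="train", partition=1, download=True)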
import os
from typing import Any
from typing import Callable, Optional
from .folder import ImageFolder
from .utils import download_and_extract_archive
......@@ -10,23 +10,21 @@ class EuroSAT(ImageFolder):
Args:
root (string): Root directory of dataset where ``root/eurosat`` exists.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. Default is False.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. Default is False.
"""
url = "https://madm.dfki.de/files/sentinel/EuroSAT.zip"
md5 = "c8fa014336c82ac7804f0398fcb19387"
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
**kwargs: Any,
) -> None:
self.root = os.path.expanduser(root)
self._base_folder = os.path.join(self.root, "eurosat")
......@@ -38,7 +36,7 @@ class EuroSAT(ImageFolder):
if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")
super().__init__(self._data_folder, **kwargs)
super().__init__(self._data_folder, transform=transform, target_transform=target_transform)
self.root = os.path.expanduser(root)
def __len__(self) -> int:
......@@ -53,4 +51,8 @@ class EuroSAT(ImageFolder):
return
os.makedirs(self._base_folder, exist_ok=True)
download_and_extract_archive(self.url, download_root=self._base_folder, md5=self.md5)
download_and_extract_archive(
"https://madm.dfki.de/files/sentinel/EuroSAT.zip",
download_root=self._base_folder,
md5="c8fa014336c82ac7804f0398fcb19387",
)
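EuroSAT now names `transform`/`target_transform` explicitly instead of forwarding `**kwargs` to `ImageFolder`. A hedged usage sketch, with a placeholder root path:

from torchvision import datasets, transforms

# transform/target_transform are explicit parameters rather than **kwargs.
eurosat = datasets.EuroSAT(root="data", transform=transforms.ToTensor(), download=True)
image, label = eurosat[0]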
......@@ -26,15 +26,15 @@ class FGVCAircraft(VisionDataset):
root (string): Root directory of the FGVC Aircraft dataset.
split (string, optional): The dataset split, supports ``train``, ``val``,
``trainval`` and ``test``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
annotation_level (str, optional): The annotation level, supports ``variant``,
``family`` and ``manufacturer``.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
_URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz"
......@@ -43,10 +43,10 @@ class FGVCAircraft(VisionDataset):
self,
root: str,
split: str = "trainval",
download: bool = False,
annotation_level: str = "variant",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
......
......@@ -24,12 +24,12 @@ class Flowers102(VisionDataset):
Args:
root (string): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image and returns a
transformed version. E.g, ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
_download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/"
......@@ -44,9 +44,9 @@ class Flowers102(VisionDataset):
self,
root: str,
split: str = "train",
download: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
......
......@@ -24,6 +24,9 @@ class Food101(VisionDataset):
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. Default is False.
"""
_URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz"
......@@ -33,9 +36,9 @@ class Food101(VisionDataset):
self,
root: str,
split: str = "train",
download: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
self._split = verify_str_arg(split, "split", ("train", "test"))
......
import csv
import os
import pathlib
from typing import Any, Callable, Optional, Tuple
import PIL
from .folder import make_dataset
from .utils import download_and_extract_archive
from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset
......@@ -14,8 +14,7 @@ class GTSRB(VisionDataset):
Args:
root (string): Root directory of the dataset.
train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set.
split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
......@@ -24,23 +23,10 @@ class GTSRB(VisionDataset):
downloaded again.
"""
# Ground Truth for the test set
_gt_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_GT.zip"
_gt_csv = "GT-final_test.csv"
_gt_md5 = "fe31e9c9270bbcd7b84b7f21a9d9d9e5"
# URLs for the test and train set
_urls = (
"https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_Images.zip",
"https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB-Training_fixed.zip",
)
_md5s = ("c7e4e6327067d32654124b0fe9e82185", "513f3c79a4c5141765e10e952eaa2478")
def __init__(
self,
root: str,
train: bool = True,
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
......@@ -48,12 +34,11 @@ class GTSRB(VisionDataset):
super().__init__(root, transform=transform, target_transform=target_transform)
self.root = os.path.expanduser(root)
self.train = train
self._base_folder = os.path.join(self.root, type(self).__name__)
self._target_folder = os.path.join(self._base_folder, "Training" if self.train else "Final_Test/Images")
self._split = verify_str_arg(split, "split", ("train", "test"))
self._base_folder = pathlib.Path(root) / "gtsrb"
self._target_folder = (
self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images")
)
if download:
self.download()
......@@ -61,12 +46,12 @@ class GTSRB(VisionDataset):
if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")
if train:
samples = make_dataset(self._target_folder, extensions=(".ppm",))
if self._split == "train":
samples = make_dataset(str(self._target_folder), extensions=(".ppm",))
else:
with open(os.path.join(self._base_folder, self._gt_csv)) as csv_file:
with open(self._base_folder / "GT-final_test.csv") as csv_file:
samples = [
(os.path.join(self._target_folder, row["Filename"]), int(row["ClassId"]))
(str(self._target_folder / row["Filename"]), int(row["ClassId"]))
for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
]
......@@ -91,16 +76,28 @@ class GTSRB(VisionDataset):
return sample, target
def _check_exists(self) -> bool:
return os.path.exists(self._target_folder) and os.path.isdir(self._target_folder)
return self._target_folder.is_dir()
def download(self) -> None:
if self._check_exists():
return
download_and_extract_archive(self._urls[self.train], download_root=self.root, md5=self._md5s[self.train])
base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"
if not self.train:
# Download Ground Truth for the test set
if self._split == "train":
download_and_extract_archive(
f"{base_url}GTSRB-Training_fixed.zip",
download_root=str(self._base_folder),
md5="513f3c79a4c5141765e10e952eaa2478",
)
else:
download_and_extract_archive(
f"{base_url}GTSRB_Final_Test_Images.zip",
download_root=str(self._base_folder),
md5="c7e4e6327067d32654124b0fe9e82185",
)
download_and_extract_archive(
self._gt_url, download_root=self.root, extract_root=self._base_folder, md5=self._gt_md5
f"{base_url}GTSRB_Final_Test_GT.zip",
download_root=str(self._base_folder),
md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5",
)
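For reference, a standalone sketch of the test-split CSV parsing used above; the two-column CSV content here is made up, but the ";" delimiter and the Filename/ClassId headers match the GTSRB ground-truth file:

import csv
import io
import pathlib

# Made-up ground-truth content; the real GT-final_test.csv has more columns,
# but DictReader only needs the two that the dataset reads.
fake_csv = "Filename;ClassId\n00000.ppm;16\n00001.ppm;1\n"
target_folder = pathlib.Path("gtsrb") / "GTSRB" / "Final_Test" / "Images"

samples = [
    (str(target_folder / row["Filename"]), int(row["ClassId"]))
    for row in csv.DictReader(io.StringIO(fake_csv), delimiter=";", skipinitialspace=True)
]
# -> a list of (image path, class id) tuples, mirroring the test-split branch above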
......@@ -45,7 +45,7 @@ class OxfordIIITPet(VisionDataset):
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = True,
download: bool = False,
):
self._split = verify_str_arg(split, "split", ("trainval", "test"))
if isinstance(target_types, str):
......
......@@ -72,10 +72,10 @@ class PCAM(VisionDataset):
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = True,
download: bool = False,
):
try:
import h5py # type: ignore[import]
import h5py
self.h5py = h5py
except ImportError:
......
......@@ -3,6 +3,7 @@ from typing import Any, Tuple, Callable, Optional
import PIL.Image
from .folder import make_dataset
from .utils import verify_str_arg, download_and_extract_archive
from .vision import VisionDataset
......@@ -21,12 +22,12 @@ class RenderedSST2(VisionDataset):
Args:
root (string): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), `"val"` and ``"test"``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. Default is False.
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. Default is False.
"""
_URL = "https://openaipublic.azureedge.net/clip/data/rendered-sst2.tgz"
......@@ -36,9 +37,9 @@ class RenderedSST2(VisionDataset):
self,
root: str,
split: str = "train",
download: bool = False,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
......@@ -53,18 +54,13 @@ class RenderedSST2(VisionDataset):
if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")
self._labels = []
self._image_files = []
for p in (self._base_folder / self._split_to_folder[self._split]).glob("**/*.png"):
self._labels.append(self.class_to_idx[p.parent.name])
self._image_files.append(p)
self._samples = make_dataset(str(self._base_folder / self._split_to_folder[self._split]), extensions=("png",))
def __len__(self) -> int:
return len(self._image_files)
return len(self._samples)
def __getitem__(self, idx) -> Tuple[Any, Any]:
image_file, label = self._image_files[idx], self._labels[idx]
image_file, label = self._samples[idx]
image = PIL.Image.open(image_file).convert("RGB")
if self.transform:
......
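The switch to `make_dataset` relies on it returning `(path, class_index)` tuples inferred from a class-per-folder layout, which is why `__getitem__` can unpack `self._samples[idx]` directly. A hedged sketch, where the directory is a placeholder laid out with one subfolder per class:

from torchvision.datasets.folder import make_dataset

# Walks a class-per-folder tree and returns (path, class_index) tuples,
# so each sample unpacks straight into (image_file, label).
samples = make_dataset("rendered-sst2/train", extensions=("png",))
image_file, label = samples[0]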
......@@ -3,7 +3,7 @@ from typing import Any, Tuple, Callable, Optional
import PIL.Image
from .utils import verify_str_arg, download_and_extract_archive
from .utils import download_and_extract_archive
from .vision import VisionDataset
......@@ -11,45 +11,31 @@ class SUN397(VisionDataset):
"""`The SUN397 Data Set <https://vision.princeton.edu/projects/2010/SUN/>`_.
The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of
397 categories with 108'754 images. The dataset also provides 10 partitions for training
and testing, with each partition consisting of 50 images per class.
397 categories with 108'754 images.
Args:
root (string): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
partition (int, optional): A valid partition can be an integer from 1 to 10 or None,
for the entire dataset.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
_DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz"
_DATASET_MD5 = "8ca2778205c41d23104230ba66911c7a"
_PARTITIONS_URL = "https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip"
_PARTITIONS_MD5 = "29a205c0a0129d21f36cbecfefe81881"
def __init__(
self,
root: str,
split: str = "train",
partition: Optional[int] = 1,
download: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
self.split = verify_str_arg(split, "split", ("train", "test"))
self.partition = partition
self._data_dir = Path(self.root) / "SUN397"
if self.partition is not None:
if self.partition < 0 or self.partition > 10:
raise RuntimeError(f"The partition parameter should be an int in [1, 10] or None, got {partition}.")
if download:
self._download()
......@@ -60,10 +46,6 @@ class SUN397(VisionDataset):
self.classes = [c[3:].strip() for c in f]
self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
if self.partition is not None:
with open(self._data_dir / f"{self.split.title()}ing_{self.partition:02d}.txt", "r") as f:
self._image_files = [self._data_dir.joinpath(*line.strip()[1:].split("/")) for line in f]
else:
self._image_files = list(self._data_dir.rglob("sun_*.jpg"))
self._labels = [
......@@ -86,13 +68,9 @@ class SUN397(VisionDataset):
return image, label
def _check_exists(self) -> bool:
return self._data_dir.exists() and self._data_dir.is_dir()
def extra_repr(self) -> str:
return "Split: {split}".format(**self.__dict__)
return self._data_dir.is_dir()
def _download(self) -> None:
if self._check_exists():
return
download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._DATASET_MD5)
download_and_extract_archive(self._PARTITIONS_URL, download_root=str(self._data_dir), md5=self._PARTITIONS_MD5)
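With `split` and `partition` removed, SUN397 now indexes the full image set under `SUN397/`. A hedged usage sketch with a placeholder root:

from torchvision import datasets

# No more split/partition arguments; download must be requested explicitly.
sun = datasets.SUN397(root="data", download=True)
image, label = sun[0]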