Unverified commit e047623a, authored by Nicolas Hug and committed by GitHub
Browse files

Cleanups for FLAVA datasets (#5164)

* Change default of download for Food101 and DTD

* Set download default to False and put it at the end

* Keep stuff private

* GTSRB: train -> split. Also use pathlib

* mypy

* Remove split and partition for SUN397

* mypy

* mypy

* move download param for SST2

* Use make_dataset in SST2

* Use a base URL for GTSRB

* Let's make this code more complicated than it needs to be because why not
parent 563d9cad
...@@ -117,3 +117,7 @@ ignore_missing_imports = True ...@@ -117,3 +117,7 @@ ignore_missing_imports = True
[mypy-torchdata.*] [mypy-torchdata.*]
ignore_missing_imports = True ignore_missing_imports = True
[mypy-h5py.*]
ignore_missing_imports = True
...@@ -2281,11 +2281,6 @@ class FGVCAircraftTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2281,11 +2281,6 @@ class FGVCAircraftTestCase(datasets_utils.ImageDatasetTestCase):
class SUN397TestCase(datasets_utils.ImageDatasetTestCase): class SUN397TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SUN397 DATASET_CLASS = datasets.SUN397
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
split=("train", "test"),
partition=(1, 10, None),
)
def inject_fake_data(self, tmpdir: str, config): def inject_fake_data(self, tmpdir: str, config):
data_dir = pathlib.Path(tmpdir) / "SUN397" data_dir = pathlib.Path(tmpdir) / "SUN397"
data_dir.mkdir() data_dir.mkdir()
...@@ -2308,18 +2303,7 @@ class SUN397TestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2308,18 +2303,7 @@ class SUN397TestCase(datasets_utils.ImageDatasetTestCase):
with open(data_dir / "ClassName.txt", "w") as file: with open(data_dir / "ClassName.txt", "w") as file:
file.writelines("\n".join(f"/{cls[0]}/{cls}" for cls in sampled_classes)) file.writelines("\n".join(f"/{cls[0]}/{cls}" for cls in sampled_classes))
if config["partition"] is not None: num_samples = len(im_paths)
num_samples = max(len(im_paths) // (2 if config["split"] == "train" else 3), 1)
with open(data_dir / f"{config['split'].title()}ing_{config['partition']:02d}.txt", "w") as file:
file.writelines(
"\n".join(
f"/{f_path.relative_to(data_dir).as_posix()}"
for f_path in random.choices(im_paths, k=num_samples)
)
)
else:
num_samples = len(im_paths)
return num_samples return num_samples
...@@ -2397,17 +2381,17 @@ class GTSRBTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2397,17 +2381,17 @@ class GTSRBTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.GTSRB DATASET_CLASS = datasets.GTSRB
FEATURE_TYPES = (PIL.Image.Image, int) FEATURE_TYPES = (PIL.Image.Image, int)
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False)) ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
def inject_fake_data(self, tmpdir: str, config): def inject_fake_data(self, tmpdir: str, config):
root_folder = os.path.join(tmpdir, "GTSRB") root_folder = os.path.join(tmpdir, "gtsrb")
os.makedirs(root_folder, exist_ok=True) os.makedirs(root_folder, exist_ok=True)
# Train data # Train data
train_folder = os.path.join(root_folder, "Training") train_folder = os.path.join(root_folder, "GTSRB", "Training")
os.makedirs(train_folder, exist_ok=True) os.makedirs(train_folder, exist_ok=True)
num_examples = 3 num_examples = 3 if config["split"] == "train" else 4
classes = ("00000", "00042", "00012") classes = ("00000", "00042", "00012")
for class_idx in classes: for class_idx in classes:
datasets_utils.create_image_folder( datasets_utils.create_image_folder(
...@@ -2419,7 +2403,7 @@ class GTSRBTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2419,7 +2403,7 @@ class GTSRBTestCase(datasets_utils.ImageDatasetTestCase):
total_number_of_examples = num_examples * len(classes) total_number_of_examples = num_examples * len(classes)
# Test data # Test data
test_folder = os.path.join(root_folder, "Final_Test", "Images") test_folder = os.path.join(root_folder, "GTSRB", "Final_Test", "Images")
os.makedirs(test_folder, exist_ok=True) os.makedirs(test_folder, exist_ok=True)
with open(os.path.join(root_folder, "GT-final_test.csv"), "w") as csv_file: with open(os.path.join(root_folder, "GT-final_test.csv"), "w") as csv_file:
......
...@@ -34,7 +34,7 @@ class CLEVRClassification(VisionDataset): ...@@ -34,7 +34,7 @@ class CLEVRClassification(VisionDataset):
split: str = "train", split: str = "train",
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = True, download: bool = False,
) -> None: ) -> None:
self._split = verify_str_arg(split, "split", ("train", "val", "test")) self._split = verify_str_arg(split, "split", ("train", "val", "test"))
super().__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
......
...@@ -32,7 +32,7 @@ class Country211(ImageFolder): ...@@ -32,7 +32,7 @@ class Country211(ImageFolder):
split: str = "train", split: str = "train",
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = True, download: bool = False,
) -> None: ) -> None:
self._split = verify_str_arg(split, "split", ("train", "valid", "test")) self._split = verify_str_arg(split, "split", ("train", "valid", "test"))
......
...@@ -21,12 +21,12 @@ class DTD(VisionDataset): ...@@ -21,12 +21,12 @@ class DTD(VisionDataset):
The partition only changes which split each image belongs to. Thus, regardless of the selected The partition only changes which split each image belongs to. Thus, regardless of the selected
partition, combining all splits will result in all images. partition, combining all splits will result in all images.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``. version. E.g, ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it. target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. Default is False.
""" """
_URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz" _URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
...@@ -37,9 +37,9 @@ class DTD(VisionDataset): ...@@ -37,9 +37,9 @@ class DTD(VisionDataset):
root: str, root: str,
split: str = "train", split: str = "train",
partition: int = 1, partition: int = 1,
download: bool = True,
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False,
) -> None: ) -> None:
self._split = verify_str_arg(split, "split", ("train", "val", "test")) self._split = verify_str_arg(split, "split", ("train", "val", "test"))
if not isinstance(partition, int) and not (1 <= partition <= 10): if not isinstance(partition, int) and not (1 <= partition <= 10):
......
import os import os
from typing import Any from typing import Callable, Optional
from .folder import ImageFolder from .folder import ImageFolder
from .utils import download_and_extract_archive from .utils import download_and_extract_archive
...@@ -10,23 +10,21 @@ class EuroSAT(ImageFolder): ...@@ -10,23 +10,21 @@ class EuroSAT(ImageFolder):
Args: Args:
root (string): Root directory of dataset where ``root/eurosat`` exists. root (string): Root directory of dataset where ``root/eurosat`` exists.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. Default is False.
transform (callable, optional): A function/transform that takes in an PIL image transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop`` and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the target_transform (callable, optional): A function/transform that takes in the
target and transforms it. target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. Default is False.
""" """
url = "https://madm.dfki.de/files/sentinel/EuroSAT.zip"
md5 = "c8fa014336c82ac7804f0398fcb19387"
def __init__( def __init__(
self, self,
root: str, root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
**kwargs: Any,
) -> None: ) -> None:
self.root = os.path.expanduser(root) self.root = os.path.expanduser(root)
self._base_folder = os.path.join(self.root, "eurosat") self._base_folder = os.path.join(self.root, "eurosat")
...@@ -38,7 +36,7 @@ class EuroSAT(ImageFolder): ...@@ -38,7 +36,7 @@ class EuroSAT(ImageFolder):
if not self._check_exists(): if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it") raise RuntimeError("Dataset not found. You can use download=True to download it")
super().__init__(self._data_folder, **kwargs) super().__init__(self._data_folder, transform=transform, target_transform=target_transform)
self.root = os.path.expanduser(root) self.root = os.path.expanduser(root)
def __len__(self) -> int: def __len__(self) -> int:
...@@ -53,4 +51,8 @@ class EuroSAT(ImageFolder): ...@@ -53,4 +51,8 @@ class EuroSAT(ImageFolder):
return return
os.makedirs(self._base_folder, exist_ok=True) os.makedirs(self._base_folder, exist_ok=True)
download_and_extract_archive(self.url, download_root=self._base_folder, md5=self.md5) download_and_extract_archive(
"https://madm.dfki.de/files/sentinel/EuroSAT.zip",
download_root=self._base_folder,
md5="c8fa014336c82ac7804f0398fcb19387",
)
...@@ -26,15 +26,15 @@ class FGVCAircraft(VisionDataset): ...@@ -26,15 +26,15 @@ class FGVCAircraft(VisionDataset):
root (string): Root directory of the FGVC Aircraft dataset. root (string): Root directory of the FGVC Aircraft dataset.
split (string, optional): The dataset split, supports ``train``, ``val``, split (string, optional): The dataset split, supports ``train``, ``val``,
``trainval`` and ``test``. ``trainval`` and ``test``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
annotation_level (str, optional): The annotation level, supports ``variant``, annotation_level (str, optional): The annotation level, supports ``variant``,
``family`` and ``manufacturer``. ``family`` and ``manufacturer``.
transform (callable, optional): A function/transform that takes in an PIL image transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop`` and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the target_transform (callable, optional): A function/transform that takes in the
target and transforms it. target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
""" """
_URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz" _URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz"
...@@ -43,10 +43,10 @@ class FGVCAircraft(VisionDataset): ...@@ -43,10 +43,10 @@ class FGVCAircraft(VisionDataset):
self, self,
root: str, root: str,
split: str = "trainval", split: str = "trainval",
download: bool = False,
annotation_level: str = "variant", annotation_level: str = "variant",
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False,
) -> None: ) -> None:
super().__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test")) self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
......
...@@ -24,12 +24,12 @@ class Flowers102(VisionDataset): ...@@ -24,12 +24,12 @@ class Flowers102(VisionDataset):
Args: Args:
root (string): Root directory of the dataset. root (string): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``. split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image and returns a transform (callable, optional): A function/transform that takes in an PIL image and returns a
transformed version. E.g, ``transforms.RandomCrop``. transformed version. E.g, ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it. target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
""" """
_download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/" _download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/"
...@@ -44,9 +44,9 @@ class Flowers102(VisionDataset): ...@@ -44,9 +44,9 @@ class Flowers102(VisionDataset):
self, self,
root: str, root: str,
split: str = "train", split: str = "train",
download: bool = True,
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False,
) -> None: ) -> None:
super().__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
self._split = verify_str_arg(split, "split", ("train", "val", "test")) self._split = verify_str_arg(split, "split", ("train", "val", "test"))
......
...@@ -24,6 +24,9 @@ class Food101(VisionDataset): ...@@ -24,6 +24,9 @@ class Food101(VisionDataset):
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``. version. E.g, ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it. target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. Default is False.
""" """
_URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz" _URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz"
...@@ -33,9 +36,9 @@ class Food101(VisionDataset): ...@@ -33,9 +36,9 @@ class Food101(VisionDataset):
self, self,
root: str, root: str,
split: str = "train", split: str = "train",
download: bool = True,
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False,
) -> None: ) -> None:
super().__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
self._split = verify_str_arg(split, "split", ("train", "test")) self._split = verify_str_arg(split, "split", ("train", "test"))
......
import csv import csv
import os import pathlib
from typing import Any, Callable, Optional, Tuple from typing import Any, Callable, Optional, Tuple
import PIL import PIL
from .folder import make_dataset from .folder import make_dataset
from .utils import download_and_extract_archive from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset from .vision import VisionDataset
...@@ -14,8 +14,7 @@ class GTSRB(VisionDataset): ...@@ -14,8 +14,7 @@ class GTSRB(VisionDataset):
Args: Args:
root (string): Root directory of the dataset. root (string): Root directory of the dataset.
train (bool, optional): If True, creates dataset from training set, otherwise split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
creates from test set.
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``. version. E.g, ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it. target_transform (callable, optional): A function/transform that takes in the target and transforms it.
...@@ -24,23 +23,10 @@ class GTSRB(VisionDataset): ...@@ -24,23 +23,10 @@ class GTSRB(VisionDataset):
downloaded again. downloaded again.
""" """
# Ground Truth for the test set
_gt_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_GT.zip"
_gt_csv = "GT-final_test.csv"
_gt_md5 = "fe31e9c9270bbcd7b84b7f21a9d9d9e5"
# URLs for the test and train set
_urls = (
"https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_Images.zip",
"https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB-Training_fixed.zip",
)
_md5s = ("c7e4e6327067d32654124b0fe9e82185", "513f3c79a4c5141765e10e952eaa2478")
def __init__( def __init__(
self, self,
root: str, root: str,
train: bool = True, split: str = "train",
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False, download: bool = False,
...@@ -48,12 +34,11 @@ class GTSRB(VisionDataset): ...@@ -48,12 +34,11 @@ class GTSRB(VisionDataset):
super().__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
self.root = os.path.expanduser(root) self._split = verify_str_arg(split, "split", ("train", "test"))
self._base_folder = pathlib.Path(root) / "gtsrb"
self.train = train self._target_folder = (
self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images")
self._base_folder = os.path.join(self.root, type(self).__name__) )
self._target_folder = os.path.join(self._base_folder, "Training" if self.train else "Final_Test/Images")
if download: if download:
self.download() self.download()
...@@ -61,12 +46,12 @@ class GTSRB(VisionDataset): ...@@ -61,12 +46,12 @@ class GTSRB(VisionDataset):
if not self._check_exists(): if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it") raise RuntimeError("Dataset not found. You can use download=True to download it")
if train: if self._split == "train":
samples = make_dataset(self._target_folder, extensions=(".ppm",)) samples = make_dataset(str(self._target_folder), extensions=(".ppm",))
else: else:
with open(os.path.join(self._base_folder, self._gt_csv)) as csv_file: with open(self._base_folder / "GT-final_test.csv") as csv_file:
samples = [ samples = [
(os.path.join(self._target_folder, row["Filename"]), int(row["ClassId"])) (str(self._target_folder / row["Filename"]), int(row["ClassId"]))
for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True) for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
] ]
...@@ -91,16 +76,28 @@ class GTSRB(VisionDataset): ...@@ -91,16 +76,28 @@ class GTSRB(VisionDataset):
return sample, target return sample, target
def _check_exists(self) -> bool: def _check_exists(self) -> bool:
return os.path.exists(self._target_folder) and os.path.isdir(self._target_folder) return self._target_folder.is_dir()
def download(self) -> None: def download(self) -> None:
if self._check_exists(): if self._check_exists():
return return
download_and_extract_archive(self._urls[self.train], download_root=self.root, md5=self._md5s[self.train]) base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"
if not self.train: if self._split == "train":
# Download Ground Truth for the test set download_and_extract_archive(
f"{base_url}GTSRB-Training_fixed.zip",
download_root=str(self._base_folder),
md5="513f3c79a4c5141765e10e952eaa2478",
)
else:
download_and_extract_archive(
f"{base_url}GTSRB_Final_Test_Images.zip",
download_root=str(self._base_folder),
md5="c7e4e6327067d32654124b0fe9e82185",
)
download_and_extract_archive( download_and_extract_archive(
self._gt_url, download_root=self.root, extract_root=self._base_folder, md5=self._gt_md5 f"{base_url}GTSRB_Final_Test_GT.zip",
download_root=str(self._base_folder),
md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5",
) )
...@@ -45,7 +45,7 @@ class OxfordIIITPet(VisionDataset): ...@@ -45,7 +45,7 @@ class OxfordIIITPet(VisionDataset):
transforms: Optional[Callable] = None, transforms: Optional[Callable] = None,
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = True, download: bool = False,
): ):
self._split = verify_str_arg(split, "split", ("trainval", "test")) self._split = verify_str_arg(split, "split", ("trainval", "test"))
if isinstance(target_types, str): if isinstance(target_types, str):
......
...@@ -72,10 +72,10 @@ class PCAM(VisionDataset): ...@@ -72,10 +72,10 @@ class PCAM(VisionDataset):
split: str = "train", split: str = "train",
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = True, download: bool = False,
): ):
try: try:
import h5py # type: ignore[import] import h5py
self.h5py = h5py self.h5py = h5py
except ImportError: except ImportError:
......
...@@ -3,6 +3,7 @@ from typing import Any, Tuple, Callable, Optional ...@@ -3,6 +3,7 @@ from typing import Any, Tuple, Callable, Optional
import PIL.Image import PIL.Image
from .folder import make_dataset
from .utils import verify_str_arg, download_and_extract_archive from .utils import verify_str_arg, download_and_extract_archive
from .vision import VisionDataset from .vision import VisionDataset
...@@ -21,12 +22,12 @@ class RenderedSST2(VisionDataset): ...@@ -21,12 +22,12 @@ class RenderedSST2(VisionDataset):
Args: Args:
root (string): Root directory of the dataset. root (string): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), `"val"` and ``"test"``. split (string, optional): The dataset split, supports ``"train"`` (default), `"val"` and ``"test"``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. Default is False.
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``. version. E.g, ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it. target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. Default is False.
""" """
_URL = "https://openaipublic.azureedge.net/clip/data/rendered-sst2.tgz" _URL = "https://openaipublic.azureedge.net/clip/data/rendered-sst2.tgz"
...@@ -36,9 +37,9 @@ class RenderedSST2(VisionDataset): ...@@ -36,9 +37,9 @@ class RenderedSST2(VisionDataset):
self, self,
root: str, root: str,
split: str = "train", split: str = "train",
download: bool = False,
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False,
) -> None: ) -> None:
super().__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
self._split = verify_str_arg(split, "split", ("train", "val", "test")) self._split = verify_str_arg(split, "split", ("train", "val", "test"))
...@@ -53,18 +54,13 @@ class RenderedSST2(VisionDataset): ...@@ -53,18 +54,13 @@ class RenderedSST2(VisionDataset):
if not self._check_exists(): if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it") raise RuntimeError("Dataset not found. You can use download=True to download it")
self._labels = [] self._samples = make_dataset(str(self._base_folder / self._split_to_folder[self._split]), extensions=("png",))
self._image_files = []
for p in (self._base_folder / self._split_to_folder[self._split]).glob("**/*.png"):
self._labels.append(self.class_to_idx[p.parent.name])
self._image_files.append(p)
def __len__(self) -> int: def __len__(self) -> int:
return len(self._image_files) return len(self._samples)
def __getitem__(self, idx) -> Tuple[Any, Any]: def __getitem__(self, idx) -> Tuple[Any, Any]:
image_file, label = self._image_files[idx], self._labels[idx] image_file, label = self._samples[idx]
image = PIL.Image.open(image_file).convert("RGB") image = PIL.Image.open(image_file).convert("RGB")
if self.transform: if self.transform:
......
...@@ -3,7 +3,7 @@ from typing import Any, Tuple, Callable, Optional ...@@ -3,7 +3,7 @@ from typing import Any, Tuple, Callable, Optional
import PIL.Image import PIL.Image
from .utils import verify_str_arg, download_and_extract_archive from .utils import download_and_extract_archive
from .vision import VisionDataset from .vision import VisionDataset
...@@ -11,45 +11,31 @@ class SUN397(VisionDataset): ...@@ -11,45 +11,31 @@ class SUN397(VisionDataset):
"""`The SUN397 Data Set <https://vision.princeton.edu/projects/2010/SUN/>`_. """`The SUN397 Data Set <https://vision.princeton.edu/projects/2010/SUN/>`_.
The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of
397 categories with 108'754 images. The dataset also provides 10 partitions for training 397 categories with 108'754 images.
and testing, with each partition consisting of 50 images per class.
Args: Args:
root (string): Root directory of the dataset. root (string): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
partition (int, optional): A valid partition can be an integer from 1 to 10 or None,
for the entire dataset.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``. version. E.g, ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it. target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
""" """
_DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz" _DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz"
_DATASET_MD5 = "8ca2778205c41d23104230ba66911c7a" _DATASET_MD5 = "8ca2778205c41d23104230ba66911c7a"
_PARTITIONS_URL = "https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip"
_PARTITIONS_MD5 = "29a205c0a0129d21f36cbecfefe81881"
def __init__( def __init__(
self, self,
root: str, root: str,
split: str = "train",
partition: Optional[int] = 1,
download: bool = True,
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
download: bool = False,
) -> None: ) -> None:
super().__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
self.split = verify_str_arg(split, "split", ("train", "test"))
self.partition = partition
self._data_dir = Path(self.root) / "SUN397" self._data_dir = Path(self.root) / "SUN397"
if self.partition is not None:
if self.partition < 0 or self.partition > 10:
raise RuntimeError(f"The partition parameter should be an int in [1, 10] or None, got {partition}.")
if download: if download:
self._download() self._download()
...@@ -60,11 +46,7 @@ class SUN397(VisionDataset): ...@@ -60,11 +46,7 @@ class SUN397(VisionDataset):
self.classes = [c[3:].strip() for c in f] self.classes = [c[3:].strip() for c in f]
self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
if self.partition is not None: self._image_files = list(self._data_dir.rglob("sun_*.jpg"))
with open(self._data_dir / f"{self.split.title()}ing_{self.partition:02d}.txt", "r") as f:
self._image_files = [self._data_dir.joinpath(*line.strip()[1:].split("/")) for line in f]
else:
self._image_files = list(self._data_dir.rglob("sun_*.jpg"))
self._labels = [ self._labels = [
self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])] for path in self._image_files self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])] for path in self._image_files
...@@ -86,13 +68,9 @@ class SUN397(VisionDataset): ...@@ -86,13 +68,9 @@ class SUN397(VisionDataset):
return image, label return image, label
def _check_exists(self) -> bool: def _check_exists(self) -> bool:
return self._data_dir.exists() and self._data_dir.is_dir() return self._data_dir.is_dir()
def extra_repr(self) -> str:
return "Split: {split}".format(**self.__dict__)
def _download(self) -> None: def _download(self) -> None:
if self._check_exists(): if self._check_exists():
return return
download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._DATASET_MD5) download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._DATASET_MD5)
download_and_extract_archive(self._PARTITIONS_URL, download_root=str(self._data_dir), md5=self._PARTITIONS_MD5)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment