Unverified commit e32b19e1, authored by Joao Gomes and committed by GitHub
Browse files

Add support for Rendered sst2 dataset (#5220)



* Adding multiweight support for shufflenetv2 prototype models

* Revert "Adding multiweight support for shufflenetv2 prototype models"

This reverts commit 31fadbee7d1a65cd73ae43dfd4ac6e97e7ca7b01.

* Adding multiweight support for shufflenetv2 prototype models

* Revert "Adding multiweight support for shufflenetv2 prototype models"

This reverts commit 4e3d900f796c1e3e667312087e77956ca4a4c017.

* Add RenderedSST2 dataset

* Address PR comments

* Fix bug in dataset verification
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent f670152b
......@@ -70,6 +70,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
PCAM
PhotoTour
Places365
RenderedSST2
QMNIST
SBDataset
SBU
......
......@@ -2665,5 +2665,27 @@ class PCAMTestCase(datasets_utils.ImageDatasetTestCase):
return num_images
class RenderedSST2TestCase(datasets_utils.ImageDatasetTestCase):
    """Exercise ``datasets.RenderedSST2`` on a fake image tree, once per split."""

    DATASET_CLASS = datasets.RenderedSST2
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test"))
    # Maps the public split name to the folder name used inside "rendered-sst2".
    SPLIT_TO_FOLDER = {"train": "train", "val": "valid", "test": "test"}

    def inject_fake_data(self, tmpdir: str, config):
        """Populate ``tmpdir`` with fake PNG files for the configured split.

        Args:
            tmpdir: Directory that acts as the dataset root for this test run.
            config: Test configuration; only ``config["split"]`` is read here.

        Returns:
            The total number of images created across both classes.
        """
        split = config["split"]
        split_dir = pathlib.Path(tmpdir) / "rendered-sst2" / self.SPLIT_TO_FOLDER[split]

        # Each split gets a different number of images per class.
        per_class = {"train": 5, "test": 6, "val": 7}[split]
        class_names = ["positive", "negative"]

        def png_name(idx):
            return f"{idx}.png"

        for class_name in class_names:
            datasets_utils.create_image_folder(
                split_dir,
                class_name,
                file_name_fn=png_name,
                num_examples=per_class,
            )

        return per_class * len(class_names)
# Entry point: run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
......@@ -29,6 +29,7 @@ from .oxford_iiit_pet import OxfordIIITPet
from .pcam import PCAM
from .phototour import PhotoTour
from .places365 import Places365
from .rendered_sst2 import RenderedSST2
from .sbd import SBDataset
from .sbu import SBU
from .semeion import SEMEION
......@@ -102,4 +103,5 @@ __all__ = (
"Country211",
"FGVCAircraft",
"EuroSAT",
"RenderedSST2",
)
from pathlib import Path
from typing import Any, Tuple, Callable, Optional
import PIL.Image
from .utils import verify_str_arg, download_and_extract_archive
from .vision import VisionDataset
class RenderedSST2(VisionDataset):
    """`The Rendered SST2 Dataset <https://github.com/openai/CLIP/blob/main/data/rendered-sst2.md>`_.

    Rendered SST2 is an image classification dataset used to evaluate a model's capability on optical
    character recognition. This dataset was generated by rendering sentences in the Stanford Sentiment
    Treebank v2 dataset.

    This dataset contains two classes (positive and negative) and is divided in three splits: a train
    split containing 6920 images (3610 positive and 3310 negative), a validation split containing 872 images
    (444 positive and 428 negative), and a test split containing 1821 images (909 positive and 912 negative).

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"`` and ``"test"``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again. Default is False.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    Raises:
        RuntimeError: If the class folders of the requested split are not found under
            ``root / "rendered-sst2"`` (after the optional download).
    """

    _URL = "https://openaipublic.azureedge.net/clip/data/rendered-sst2.tgz"
    _MD5 = "2384d08e9dcfa4bd55b324e610496ee5"

    def __init__(
        self,
        root: str,
        split: str = "train",
        download: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        self._split = verify_str_arg(split, "split", ("train", "val", "test"))
        # The validation split lives in a folder named "valid" on disk.
        self._split_to_folder = {"train": "train", "val": "valid", "test": "test"}
        self._base_folder = Path(self.root) / "rendered-sst2"
        self.classes = ["negative", "positive"]
        self.class_to_idx = {"negative": 0, "positive": 1}

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        split_folder = self._base_folder / self._split_to_folder[self._split]
        # Directory enumeration order is filesystem-dependent; sorting makes the
        # index -> sample mapping reproducible across machines and runs.
        self._image_files = sorted(split_folder.glob("**/*.png"))
        # The label of an image is given by the name of the folder that contains it.
        self._labels = [self.class_to_idx[path.parent.name] for path in self._image_files]

    def __len__(self) -> int:
        """Return the number of images in the selected split."""
        return len(self._image_files)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Load the sample at position ``idx``.

        Returns:
            A ``(image, label)`` tuple: the image converted to RGB and the class
            index, each passed through its transform when one was given.
        """
        image_file, label = self._image_files[idx], self._labels[idx]
        # ``convert`` returns a new, fully loaded image, so the file handle can be
        # released as soon as the conversion is done.
        with PIL.Image.open(image_file) as raw_image:
            image = raw_image.convert("RGB")

        # Compare against None explicitly: a callable that defines ``__len__``
        # could be falsy and would otherwise be silently skipped.
        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

    def extra_repr(self) -> str:
        """Return the split description used in the dataset's repr."""
        return f"split={self._split}"

    def _check_exists(self) -> bool:
        """Return True if every class folder of the selected split exists on disk."""
        split_folder = self._base_folder / self._split_to_folder[self._split]
        return all((split_folder / class_label).is_dir() for class_label in self.classes)

    def _download(self) -> None:
        """Download and extract the archive unless the split is already present."""
        if self._check_exists():
            return
        # NOTE(review): presumably the archive unpacks into a top-level
        # "rendered-sst2" folder under ``root`` -- confirm against the archive layout.
        download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment