Unverified commit ffd0d237, authored by puhuk and committed by GitHub

Add Country211 dataset (#5138)



* Add Country211 dataset

To address issue #5108.

* Add Country211 dataset

To address issue #5108.

* Update country211.py

* Update country211.py

* Reflect code review feedback

* Update test_datasets.py

* Update with review

* inherit from ImageFolder

* Update test/test_datasets.py

* Docstring + minor test update
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Nicolas Hug <nicolashug@fb.com>
parent cc0d1beb
@@ -38,6 +38,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
Cityscapes
CocoCaptions
CocoDetection
Country211
DTD
EMNIST
FakeData
...
@@ -2463,5 +2463,32 @@ class OxfordIIITPetTestCase(datasets_utils.ImageDatasetTestCase):
        return (image_id, class_id, species, breed_id)

class Country211TestCase(datasets_utils.ImageDatasetTestCase):
    DATASET_CLASS = datasets.Country211

    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "valid", "test"))

    def inject_fake_data(self, tmpdir: str, config):
        split_folder = pathlib.Path(tmpdir) / "country211" / config["split"]
        split_folder.mkdir(parents=True, exist_ok=True)

        num_examples = {
            "train": 3,
            "valid": 4,
            "test": 5,
        }[config["split"]]

        classes = ("AD", "BS", "GR")
        for cls in classes:
            datasets_utils.create_image_folder(
                split_folder,
                name=cls,
                file_name_fn=lambda idx: f"{idx}.jpg",
                num_examples=num_examples,
            )

        return num_examples * len(classes)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
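
For reference, a quick way to exercise just the new test case is pytest's keyword filter; this is a sketch and assumes a local source checkout with the test dependencies installed:

    pytest test/test_datasets.py -k Country211 -v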
@@ -5,6 +5,7 @@ from .cifar import CIFAR10, CIFAR100
from .cityscapes import Cityscapes
from .clevr import CLEVRClassification
from .coco import CocoCaptions, CocoDetection
from .country211 import Country211
from .dtd import DTD
from .fakedata import FakeData
from .fer2013 import FER2013
@@ -91,4 +92,5 @@ __all__ = (
"GTSRB", "GTSRB",
"CLEVRClassification", "CLEVRClassification",
"OxfordIIITPet", "OxfordIIITPet",
"Country211",
) )

from pathlib import Path
from typing import Callable, Optional

from .folder import ImageFolder
from .utils import verify_str_arg, download_and_extract_archive


class Country211(ImageFolder):
    """`The Country211 Data Set <https://github.com/openai/CLIP/blob/main/data/country211.md>`_ from OpenAI.

    This dataset was built by filtering the images from the YFCC100m dataset
    that have GPS coordinates corresponding to an ISO-3166 country code. The
    dataset is balanced by sampling 150 train images, 50 validation images, and
    100 test images for each country.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"valid"`` and ``"test"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g., ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and puts it into
            ``root/country211/``. If the dataset is already downloaded, it is not downloaded again.
    """

    _URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz"
    _MD5 = "84988d7644798601126c29e9877aab6a"

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = True,
    ) -> None:
        self._split = verify_str_arg(split, "split", ("train", "valid", "test"))

        root = Path(root).expanduser()
        self.root = str(root)
        self._base_folder = root / "country211"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        super().__init__(str(self._base_folder / self._split), transform=transform, target_transform=target_transform)
        self.root = str(root)  # ImageFolder.__init__ sets root to the split folder; restore the user-supplied root

    def _check_exists(self) -> bool:
        return self._base_folder.exists() and self._base_folder.is_dir()

    def _download(self) -> None:
        if self._check_exists():
            return
        download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
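
As a usage sketch (not part of the diff), the new class behaves like any other ImageFolder-backed torchvision dataset; the root path and the ToTensor transform below are illustrative assumptions:

    from torchvision import datasets, transforms

    # Downloads and extracts the archive into <root>/country211/ on first use.
    country211 = datasets.Country211(
        root="data",
        split="train",
        transform=transforms.ToTensor(),
        download=True,
    )

    img, label = country211[0]  # tensor image and integer class index
    print(len(country211), country211.classes[label])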