Unverified Commit 57a77c45 authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Adds support for EuroSAT dataset (#5114)



* feat: Added EuroSAT dataset

* test: Added unittest

* docs: Improved comments

* docs: Updated the documentation

* docs: Removed unnecessary comments

* fix: Fixed class implementation

* test: Fixed unittest

* fix: Fixed magic method len

* test: Fixed unittest

* refactor: Refactored EuroSAT

* refactor: Applied modifications

* Apply suggestions from code review

* refactor: Applied request changes

* refactor: Made var explicit

* fix: Fixed attribute initialization order

* refactor: Removed name mapping
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 49468279
...@@ -41,6 +41,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas ...@@ -41,6 +41,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
Country211 Country211
DTD DTD
EMNIST EMNIST
EuroSAT
FakeData FakeData
FashionMNIST FashionMNIST
FER2013 FER2013
......
...@@ -2169,6 +2169,27 @@ class HD1KTestCase(KittiFlowTestCase): ...@@ -2169,6 +2169,27 @@ class HD1KTestCase(KittiFlowTestCase):
return num_sequences * (num_examples_per_sequence - 1) return num_sequences * (num_examples_per_sequence - 1)
class EuroSATTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.EuroSAT
FEATURE_TYPES = (PIL.Image.Image, int)
def inject_fake_data(self, tmpdir, config):
data_folder = os.path.join(tmpdir, "eurosat", "2750")
os.makedirs(data_folder)
num_examples_per_class = 3
classes = ("AnnualCrop", "Forest")
for cls in classes:
datasets_utils.create_image_folder(
root=data_folder,
name=cls,
file_name_fn=lambda idx: f"{cls}_{idx}.jpg",
num_examples=num_examples_per_class,
)
return len(classes) * num_examples_per_class
class Food101TestCase(datasets_utils.ImageDatasetTestCase): class Food101TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Food101 DATASET_CLASS = datasets.Food101
FEATURE_TYPES = (PIL.Image.Image, int) FEATURE_TYPES = (PIL.Image.Image, int)
......
...@@ -7,6 +7,7 @@ from .clevr import CLEVRClassification ...@@ -7,6 +7,7 @@ from .clevr import CLEVRClassification
from .coco import CocoCaptions, CocoDetection from .coco import CocoCaptions, CocoDetection
from .country211 import Country211 from .country211 import Country211
from .dtd import DTD from .dtd import DTD
from .eurosat import EuroSAT
from .fakedata import FakeData from .fakedata import FakeData
from .fer2013 import FER2013 from .fer2013 import FER2013
from .fgvc_aircraft import FGVCAircraft from .fgvc_aircraft import FGVCAircraft
...@@ -98,4 +99,5 @@ __all__ = ( ...@@ -98,4 +99,5 @@ __all__ = (
"OxfordIIITPet", "OxfordIIITPet",
"Country211", "Country211",
"FGVCAircraft", "FGVCAircraft",
"EuroSAT",
) )
import os
from typing import Any
from .folder import ImageFolder
from .utils import download_and_extract_archive
class EuroSAT(ImageFolder):
"""RGB version of the `EuroSAT <https://github.com/phelber/eurosat>`_ Dataset.
Args:
root (string): Root directory of dataset where ``root/eurosat`` exists.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. Default is False.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
url = "https://madm.dfki.de/files/sentinel/EuroSAT.zip"
md5 = "c8fa014336c82ac7804f0398fcb19387"
def __init__(
self,
root: str,
download: bool = False,
**kwargs: Any,
) -> None:
self.root = os.path.expanduser(root)
self._base_folder = os.path.join(self.root, "eurosat")
self._data_folder = os.path.join(self._base_folder, "2750")
if download:
self.download()
if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")
super().__init__(self._data_folder, **kwargs)
self.root = os.path.expanduser(root)
def __len__(self) -> int:
return len(self.samples)
def _check_exists(self) -> bool:
return os.path.exists(self._data_folder)
def download(self) -> None:
if self._check_exists():
return
os.makedirs(self._base_folder, exist_ok=True)
download_and_extract_archive(self.url, download_root=self._base_folder, md5=self.md5)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment