Unverified Commit 97385df0 authored by Drishti Bhasin's avatar Drishti Bhasin Committed by GitHub
Browse files

add EuroSAT prototype dataset (#5452)



* add eurosat

* revert formatting

* port test and make style changes

* add eurosat to __init__

* fix pathlib error

* create dataset zipfile and revert pre commit changes

* remove unnecessary variable in resources

* revert auto formatter changes and modify ufmt version

* revert change to contributing guide
Co-authored-by: Dbhasin1 <drishti_b@me.iitr.c.in>
Co-authored-by: Nicolas Hug <nicolashug@fb.com>
parent 01f07eeb
...@@ -1327,6 +1327,24 @@ def cub200(info, root, config): ...@@ -1327,6 +1327,24 @@ def cub200(info, root, config):
return num_samples_map[config.split] return num_samples_map[config.split]
@register_mock
def eurosat(info, root, config):
    """Build a small fake EuroSAT download under ``root``.

    Creates ``<root>/eurosat/2750`` with one image folder per mock class,
    hands that directory to ``make_zip`` under the archive name
    ``EuroSAT.zip``, and returns the total number of generated samples
    (classes times images per class).
    """
    class_names = ("AnnualCrop", "Forest")
    images_per_class = 3

    # Fails if the directory already exists (no exist_ok), same as before.
    images_root = pathlib.Path(root) / "eurosat" / "2750"
    images_root.mkdir(parents=True)

    for class_name in class_names:
        # File names follow the "<class>_<index>.jpg" pattern.
        create_image_folder(
            root=images_root,
            name=class_name,
            file_name_fn=lambda idx: f"{class_name}_{idx}.jpg",
            num_examples=images_per_class,
        )

    make_zip(root, "EuroSAT.zip", images_root)

    return images_per_class * len(class_names)
@register_mock @register_mock
def svhn(info, root, config): def svhn(info, root, config):
import scipy.io as sio import scipy.io as sio
......
...@@ -6,6 +6,7 @@ from .coco import Coco ...@@ -6,6 +6,7 @@ from .coco import Coco
from .country211 import Country211 from .country211 import Country211
from .cub200 import CUB200 from .cub200 import CUB200
from .dtd import DTD from .dtd import DTD
from .eurosat import EuroSAT
from .fer2013 import FER2013 from .fer2013 import FER2013
from .gtsrb import GTSRB from .gtsrb import GTSRB
from .imagenet import ImageNet from .imagenet import ImageNet
......
import pathlib
from typing import Any, Dict, List, Tuple
from torchdata.datapipes.iter import IterDataPipe, Mapper
from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling
from torchvision.prototype.features import EncodedImage, Label
class EuroSAT(Dataset):
    """Prototype dataset definition for the EuroSAT image archive."""

    def _make_info(self) -> DatasetInfo:
        """Return the static metadata: name, homepage and the ten category names."""
        return DatasetInfo(
            "eurosat",
            homepage="https://github.com/phelber/eurosat",
            categories=(
                "AnnualCrop",
                "Forest",
                "HerbaceousVegetation",
                "Highway",
                # "Industrial" and "Pasture" are two separate categories. They
                # must be two separate literals: adjacent string literals are
                # implicitly concatenated into a single tuple element.
                "Industrial",
                "Pasture",
                "PermanentCrop",
                "Residential",
                "River",
                "SeaLake",
            ),
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        """Return the single HTTP resource (the EuroSAT zip) with its SHA-256 checksum."""
        return [
            HttpResource(
                "https://madm.dfki.de/files/sentinel/EuroSAT.zip",
                sha256="8ebea626349354c5328b142b96d0430e647051f26efc2dc974c843f25ecf70bd",
            )
        ]

    def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
        """Turn a ``(path, buffer)`` pair into a sample dictionary.

        The category is taken from the name of the directory that directly
        contains the file. The returned dict has the keys ``label``, ``path``
        and ``image``.
        """
        path, buffer = data
        category = pathlib.Path(path).parent.name
        return dict(
            # self.categories is presumably supplied by the Dataset base class
            # from the DatasetInfo above -- confirm against Dataset.
            label=Label.from_category(category, categories=self.categories),
            path=path,
            image=EncodedImage.from_file(buffer),
        )

    def _make_datapipe(
        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
    ) -> IterDataPipe[Dict[str, Any]]:
        """Build the sample datapipe from the first (and only) resource datapipe.

        Applies the sharding and shuffling hints, then maps every item through
        ``_prepare_sample``.
        """
        dp = resource_dps[0]
        dp = hint_sharding(dp)
        dp = hint_shuffling(dp)
        return Mapper(dp, self._prepare_sample)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment