import os
from typing import Callable, Optional

from .folder import ImageFolder
from .utils import download_and_extract_archive


class EuroSAT(ImageFolder):
    """RGB version of the `EuroSAT <https://github.com/phelber/eurosat>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``root/eurosat`` exists.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again. Default is False.

    Raises:
        RuntimeError: If the image folder ``root/eurosat/2750`` does not exist after
            the optional download step.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        self.root = os.path.expanduser(root)
        # On-disk layout: the archive is downloaded into <root>/eurosat/ and the
        # images handed to ImageFolder are expected under <root>/eurosat/2750/.
        self._base_folder = os.path.join(self.root, "eurosat")
        self._data_folder = os.path.join(self._base_folder, "2750")

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        super().__init__(self._data_folder, transform=transform, target_transform=target_transform)
        # NOTE(review): presumably ImageFolder's initializer rebinds ``self.root`` to the
        # ``2750`` data folder it was given; re-assigning here restores the user-supplied
        # root directory. Confirm against ``.folder``.
        self.root = os.path.expanduser(root)

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        # ``self.samples`` is presumably populated by ImageFolder during __init__.
        return len(self.samples)

    def _check_exists(self) -> bool:
        """Return True if the image folder ``root/eurosat/2750`` exists on disk."""
        return os.path.exists(self._data_folder)

    def download(self) -> None:
        """Download and extract the EuroSAT archive into ``root/eurosat``.

        Does nothing if ``root/eurosat/2750`` already exists. Otherwise the
        ``root/eurosat`` directory is created (if needed) and the archive is
        downloaded there, with its expected MD5 checksum passed along.
        """
        if self._check_exists():
            return

        os.makedirs(self._base_folder, exist_ok=True)
        download_and_extract_archive(
            "https://madm.dfki.de/files/sentinel/EuroSAT.zip",
            download_root=self._base_folder,
            md5="c8fa014336c82ac7804f0398fcb19387",
        )