omniglot.py 4.05 KB
Newer Older
Sanyam Kapoor's avatar
Sanyam Kapoor committed
1
from os.path import join
2
3
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union
4
5
6

from PIL import Image

7
from .utils import check_integrity, download_and_extract_archive, list_dir, list_files
8
from .vision import VisionDataset
Sanyam Kapoor's avatar
Sanyam Kapoor committed
9
10


11
class Omniglot(VisionDataset):
    """`Omniglot <https://github.com/brendenlake/omniglot>`_ Dataset.

    Each sample is a ``(image, target)`` pair where ``target`` is the integer index of
    the character class, i.e. the position of the ``alphabet/character`` directory in
    the listing built at construction time.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where directory
            ``omniglot-py`` exists.
        background (bool, optional): If True, creates dataset from the "background" set, otherwise
            creates from the "evaluation" set. This terminology is defined by the authors.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset zip files from the internet and
            puts it in root directory. If the zip files are already downloaded, they are not
            downloaded again.

    Raises:
        RuntimeError: If the zip archive of the selected set is missing or fails its
            MD5 check (after the optional download).
    """

    # Sub-directory of ``root`` that holds the archives and the extracted images.
    folder = "omniglot-py"
    download_url_prefix = "https://raw.githubusercontent.com/brendenlake/omniglot/master/python"
    # Expected MD5 digest of each zip archive, keyed by archive name without ".zip".
    zips_md5 = {
        "images_background": "68d2efa1b9178cc56df9314c21c6e718",
        "images_evaluation": "6b91aef0f799c5bb55b94e3f2daec811",
    }

    def __init__(
        self,
        root: Union[str, Path],
        background: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(join(root, self.folder), transform=transform, target_transform=target_transform)
        self.background = background

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        # NOTE(review): ``self.root`` is presumably set by ``VisionDataset.__init__`` to
        # ``<root>/omniglot-py`` -- confirm against the base class.
        self.target_folder = join(self.root, self._get_target_folder())
        self._alphabets = list_dir(self.target_folder)

        # One "alphabet/character" relative path per character class. Built with a
        # nested comprehension rather than ``sum(lists, [])``, which re-copies the
        # accumulated list on every addition (quadratic in the number of classes).
        self._characters: List[str] = [
            join(alphabet, character)
            for alphabet in self._alphabets
            for character in list_dir(join(self.target_folder, alphabet))
        ]

        # For every character class, the list of (image file name, class index) pairs.
        self._character_images = [
            [(image, idx) for image in list_files(join(self.target_folder, character), ".png")]
            for idx, character in enumerate(self._characters)
        ]

        # Same pairs flattened into a single list (class order preserved); this is what
        # ``__len__`` and ``__getitem__`` index into.
        self._flat_character_images: List[Tuple[str, int]] = [
            entry for images in self._character_images for entry in images
        ]

    def __len__(self) -> int:
        """Return the total number of images across all character classes."""
        return len(self._flat_character_images)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target character class.
        """
        image_name, character_class = self._flat_character_images[index]
        image_path = join(self.target_folder, self._characters[character_class], image_name)
        # Images are always converted to PIL mode "L" before any transform is applied.
        image = Image.open(image_path, mode="r").convert("L")

        # Compare against None explicitly: a callable that happens to evaluate as falsy
        # (e.g. an empty container-like transform) must still be applied.
        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            character_class = self.target_transform(character_class)

        return image, character_class

    def _check_integrity(self) -> bool:
        """Check the zip archive of the selected set against its expected MD5 digest.

        Only the ``<set>.zip`` file inside ``self.root`` is checked; the extracted image
        folder is not inspected here.
        """
        zip_filename = self._get_target_folder()
        return check_integrity(join(self.root, zip_filename + ".zip"), self.zips_md5[zip_filename])

    def download(self) -> None:
        """Download and extract the archive of the selected set unless it is already valid."""
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        filename = self._get_target_folder()
        zip_filename = filename + ".zip"
        url = f"{self.download_url_prefix}/{zip_filename}"
        download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename])

    def _get_target_folder(self) -> str:
        """Return the archive/folder name of the selected set (background or evaluation)."""
        return "images_background" if self.background else "images_evaluation"