omniglot.py 4 KB
Newer Older
Sanyam Kapoor's avatar
Sanyam Kapoor committed
1
from os.path import join
2
from typing import Any, Callable, List, Optional, Tuple
3
4
5

from PIL import Image

6
from .utils import download_and_extract_archive, check_integrity, list_dir, list_files
7
from .vision import VisionDataset
Sanyam Kapoor's avatar
Sanyam Kapoor committed
8
9


10
class Omniglot(VisionDataset):
    """`Omniglot <https://github.com/brendenlake/omniglot>`_ Dataset.

    Each sample is an ``(image, target)`` pair. ``image`` is opened with PIL and
    converted to mode ``"L"``; ``target`` is the position of the sample's
    ``alphabet/character`` folder in the list built at construction time.

    Args:
        root (string): Root directory of dataset where directory
            ``omniglot-py`` exists.
        background (bool, optional): If True, creates dataset from the "background" set, otherwise
            creates from the "evaluation" set. This terminology is defined by the authors.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset zip files from the internet and
            puts it in root directory. If the zip files are already downloaded, they are not
            downloaded again.

    Raises:
        RuntimeError: If the zip archive of the selected set does not pass
            ``check_integrity`` (after the optional download).
    """

    # Sub-directory of the user-supplied ``root`` that holds the dataset.
    folder = "omniglot-py"
    # Base URL; the archive name (``<set>.zip``) is appended to it in ``download``.
    download_url_prefix = "https://raw.githubusercontent.com/brendenlake/omniglot/master/python"
    # Checksum handed to the integrity/download helpers, keyed by archive name without ".zip".
    zips_md5 = {
        "images_background": "68d2efa1b9178cc56df9314c21c6e718",
        "images_evaluation": "6b91aef0f799c5bb55b94e3f2daec811",
    }

    def __init__(
        self,
        root: str,
        background: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        # NOTE(review): the base class is given ``root/omniglot-py``; ``self.root`` used
        # below is presumably that joined path -- confirm against VisionDataset.
        super().__init__(join(root, self.folder), transform=transform, target_transform=target_transform)
        self.background = background

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        self.target_folder = join(self.root, self._get_target_folder())
        self._alphabets = list_dir(self.target_folder)

        # Relative "alphabet/character" paths; a character's position in this list is its class index.
        # Built with a nested comprehension rather than ``sum(lists, [])``, which re-copies the
        # accumulated list on every addition.
        self._characters: List[str] = [
            join(alphabet, character)
            for alphabet in self._alphabets
            for character in list_dir(join(self.target_folder, alphabet))
        ]

        # One inner list per character, holding (image file name, class index) pairs.
        self._character_images: List[List[Tuple[str, int]]] = [
            [(image, idx) for image in list_files(join(self.target_folder, character), ".png")]
            for idx, character in enumerate(self._characters)
        ]

        # Same pairs flattened in character order; this is what ``__len__``/``__getitem__`` index.
        self._flat_character_images: List[Tuple[str, int]] = [
            pair for images in self._character_images for pair in images
        ]

    def __len__(self) -> int:
        """Return the total number of images across all characters."""
        return len(self._flat_character_images)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target character class.
        """
        image_name, character_class = self._flat_character_images[index]
        image_path = join(self.target_folder, self._characters[character_class], image_name)
        image = Image.open(image_path, mode="r").convert("L")

        # Test for presence, not truthiness: a callable that defines ``__len__`` or
        # ``__bool__`` may evaluate as falsy and must still be applied.
        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            character_class = self.target_transform(character_class)

        return image, character_class

    def _check_integrity(self) -> bool:
        """Return whether the zip archive of the selected set passes ``check_integrity``."""
        zip_filename = self._get_target_folder()
        return check_integrity(join(self.root, zip_filename + ".zip"), self.zips_md5[zip_filename])

    def download(self) -> None:
        """Download and extract the archive of the selected set unless it is already verified."""
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        filename = self._get_target_folder()
        zip_filename = filename + ".zip"
        url = self.download_url_prefix + "/" + zip_filename
        download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename])

    def _get_target_folder(self) -> str:
        """Return the folder/archive base name selected by ``self.background``."""
        return "images_background" if self.background else "images_evaluation"