semeion.py 3.01 KB
Newer Older
neoglez's avatar
neoglez committed
1
import os.path
2
from typing import Any, Callable, Optional, Tuple
3
4
5
6

import numpy as np
from PIL import Image

7
from .utils import check_integrity, download_url
8
from .vision import VisionDataset
neoglez's avatar
neoglez committed
9
10


11
class SEMEION(VisionDataset):
    r"""`SEMEION <http://archive.ics.uci.edu/ml/datasets/semeion+handwritten+digit>`_ Dataset.

    Each sample is a 16x16 grayscale image (returned as a PIL image) together
    with an integer class label.

    Args:
        root (str): Root directory of the dataset, i.e. the directory in which the
            file ``semeion.data`` is stored (or will be downloaded to).
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If, after the optional download, the data file is missing or
            fails the MD5 integrity check.
    """

    url = "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data"
    filename = "semeion.data"
    md5_checksum = "cb545d371d2ce14ec121470795a77432"

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = True,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        data = np.loadtxt(os.path.join(self.root, self.filename))

        # Each row: the first 256 columns are pixels, the remaining columns
        # encode the label. Pixels are scaled by 255 and stored as uint8
        # (assumes the file's pixel values lie in [0, 1] -- TODO confirm),
        # then reshaped into 16x16 images.
        self.data = (data[:, :256] * 255).astype("uint8").reshape(-1, 16, 16)

        # The label is the column index of the set entry in the label block.
        # argmax yields exactly one label per row, so labels always stay
        # aligned with images. np.nonzero(...)[1] would drop or duplicate
        # entries (shifting all later labels) if a row were not exactly one-hot.
        # For one-hot rows both give the same result.
        self.labels = np.argmax(data[:, 256:], axis=1)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.labels[index])

        # Return a PIL Image for consistency with the other datasets. A 2-D
        # uint8 array is converted to mode "L" automatically; passing
        # ``mode=`` explicitly is deprecated in recent Pillow releases.
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return len(self.data)

    def _check_integrity(self) -> bool:
        """Return True if ``root/semeion.data`` passes ``check_integrity`` against ``md5_checksum``."""
        return bool(check_integrity(os.path.join(self.root, self.filename), self.md5_checksum))

    def download(self) -> None:
        """Download ``semeion.data`` into ``root`` unless a verified copy is already present."""
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_url(self.url, self.root, self.filename, self.md5_checksum)