semeion.py 3.03 KB
Newer Older
neoglez's avatar
neoglez committed
1
2
import os
import os.path
3
from typing import Any, Callable, Optional, Tuple
4
5
6
7

import numpy as np
from PIL import Image

neoglez's avatar
neoglez committed
8
from .utils import download_url, check_integrity
9
from .vision import VisionDataset
neoglez's avatar
neoglez committed
10
11


12
class SEMEION(VisionDataset):
    r"""`SEMEION <http://archive.ics.uci.edu/ml/datasets/semeion+handwritten+digit>`_ Dataset.

    Args:
        root (string): Root directory of the dataset, i.e. the directory in which
            the ``semeion.data`` file is looked up (and downloaded to).
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true (the default), downloads the dataset from the
            internet and puts it in root directory. If dataset is already downloaded,
            it is not downloaded again.

    Raises:
        RuntimeError: If the integrity check of the data file under ``root`` fails
            after the optional download step.
    """

    # Remote location of the raw data file.
    url = "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data"
    # Name of the data file inside ``root``.
    filename = "semeion.data"
    # Checksum handed to the integrity/download helpers.
    md5_checksum = "cb545d371d2ce14ec121470795a77432"

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = True,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        fp = os.path.join(self.root, self.filename)
        data = np.loadtxt(fp)

        # The first 256 columns of every row hold the pixels of one 16x16 image.
        # Scale them by 255 (white is 255) and store them as 8-bit unsigned integers.
        pixels = (data[:, :256] * 255).astype("uint8")
        self.data = np.reshape(pixels, (-1, 16, 16))

        # The remaining columns mark the class: the column index of each non-zero
        # entry becomes a label (presumably one non-zero entry per row, i.e. a
        # one-hot encoding -- the labels only line up with the images if so).
        self.labels = np.nonzero(data[:, 256:])[1]

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.labels[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img, mode="L")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of images held in ``self.data``."""
        return len(self.data)

    def _check_integrity(self) -> bool:
        """Return the result of checking ``root/filename`` against ``md5_checksum``."""
        fpath = os.path.join(self.root, self.filename)
        return check_integrity(fpath, self.md5_checksum)

    def download(self) -> None:
        """Download the data file into ``root`` unless it already passes the integrity check."""
        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_url(self.url, self.root, self.filename, self.md5_checksum)