import os
from typing import Any, Callable, Optional, Tuple

import numpy as np
from PIL import Image

from .utils import download_url
from .vision import VisionDataset


class USPS(VisionDataset):
    """`USPS <https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps>`_ Dataset.
    The data-format is : [label [index:value ]*256 \\n] * num_lines, where ``label`` lies in ``[1, 10]``.
    The value for each pixel lies in ``[-1, 1]``. Here we transform the ``label`` into ``[0, 9]``
    and make pixel values in ``[0, 255]``.

    Args:
        root (string): Root directory of dataset to store ``USPS`` data files.
        train (bool, optional): If True, creates dataset from ``usps.bz2``,
            otherwise from ``usps.t.bz2``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        FileNotFoundError: If the data file for the requested split is not present
            under ``root`` after the optional download step.
    """

    # split name -> [download URL, local filename, expected MD5 checksum]
    split_list = {
        "train": [
            "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2",
            "usps.bz2",
            "ec16c51db3855ca6c91edd34d0e9b197",
        ],
        "test": [
            "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2",
            "usps.t.bz2",
            "8ea070ee2aca1ac39742fdd1ef5ed118",
        ],
    }

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        split = "train" if train else "test"
        url, filename, checksum = self.split_list[split]
        full_path = os.path.join(self.root, filename)

        if download and not os.path.exists(full_path):
            download_url(url, self.root, filename, md5=checksum)

        # Fail early with an actionable message instead of a bare error from bz2.open.
        # Same exception type as before, so existing handlers keep working.
        if not os.path.exists(full_path):
            raise FileNotFoundError(
                f"USPS data file not found at {full_path!r}. "
                "You can use download=True to download it."
            )

        # Presumably imported lazily so importing this module does not require bz2.
        import bz2

        with bz2.open(full_path) as fp:
            # Each line is "<label> <index>:<value> ...". Skip blank lines (e.g. a
            # trailing newline) that would otherwise produce an empty token list.
            raw_data = [tokens for tokens in (line.decode().split() for line in fp) if tokens]

        # Keep only the value part of each "index:value" pair; the file lists all
        # 256 pixels in order, so position alone determines the pixel location.
        pixel_values = [[pair.split(":")[-1] for pair in tokens[1:]] for tokens in raw_data]
        imgs = np.asarray(pixel_values, dtype=np.float32).reshape((-1, 16, 16))
        # Rescale pixels from [-1, 1] to [0, 255] (truncating) for 8-bit grayscale images.
        imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
        # Labels in the file are in [1, 10]; shift to zero-based class indices [0, 9].
        targets = [int(tokens[0]) - 1 for tokens in raw_data]

        self.data = imgs
        self.targets = targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img, mode="L")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of samples in the loaded split."""
        return len(self.data)