usps.py 3.42 KB
Newer Older
Francisco Massa's avatar
Francisco Massa committed
1
import os
limm's avatar
limm committed
2
3
4
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union

Francisco Massa's avatar
Francisco Massa committed
5
import numpy as np
limm's avatar
limm committed
6
from PIL import Image
Francisco Massa's avatar
Francisco Massa committed
7
8
9
10
11
12
13

from .utils import download_url
from .vision import VisionDataset


class USPS(VisionDataset):
    """`USPS <https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps>`_ Dataset.

    Every line of the raw file is laid out as ``label [index:value ]*256``, where
    ``label`` lies in ``[1, 10]`` and each pixel ``value`` lies in ``[-1, 1]``.
    While loading, labels are shifted down to ``[0, 9]`` and pixel values are
    rescaled to ``[0, 255]`` and stored as ``uint8`` images of size 16x16.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset to store ``USPS`` data files.
        train (bool, optional): If True, creates dataset from ``usps.bz2``,
            otherwise from ``usps.t.bz2``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """

    # Per split: [download URL, file name inside ``root``, MD5 digest of the archive].
    split_list = {
        "train": [
            "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2",
            "usps.bz2",
            "ec16c51db3855ca6c91edd34d0e9b197",
        ],
        "test": [
            "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2",
            "usps.t.bz2",
            "8ea070ee2aca1ac39742fdd1ef5ed118",
        ],
    }

    def __init__(
        self,
        root: Union[str, Path],
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)

        source_url, archive_name, md5_digest = self.split_list["train" if train else "test"]
        archive_path = os.path.join(self.root, archive_name)

        # Only hit the network when asked to and the archive is not on disk yet.
        if download:
            if not os.path.exists(archive_path):
                download_url(source_url, self.root, archive_name, md5=md5_digest)

        import bz2

        # Tokenise every line on whitespace: the first token is the label,
        # the remaining tokens are ``index:value`` pixel entries.
        with bz2.open(archive_path) as stream:
            records = [raw_line.decode().split() for raw_line in stream]

        # Keep only the text after the last ":" of each pixel entry; NumPy
        # performs the string -> float32 conversion.
        pixel_tokens = [[entry.rpartition(":")[2] for entry in record[1:]] for record in records]
        pixels = np.asarray(pixel_tokens, dtype=np.float32).reshape((-1, 16, 16))
        # Map [-1, 1] onto [0, 255]; the cast to uint8 truncates the fraction.
        pixels = ((pixels + 1) / 2 * 255).astype(np.uint8)

        # Labels in the file start at 1; shift them so they start at 0.
        labels = [int(record[0]) - 1 for record in records]

        self.data = pixels
        self.targets = labels

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        pixels = self.data[index]
        label = int(self.targets[index])

        # Hand back a PIL Image so that this dataset is consistent
        # with all other datasets.
        picture = Image.fromarray(pixels, mode="L")

        if self.transform is not None:
            picture = self.transform(picture)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return picture, label

    def __len__(self) -> int:
        """Return the number of samples held in ``self.data``."""
        return len(self.data)