# svhn.py
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import numpy as np
import sys
from .utils import download_url, check_integrity


class SVHN(data.Dataset):
    """SVHN dataset read from one of the ``*_32x32.mat`` files in ``split_list``.

    Each item is an ``(image, target)`` pair: ``image`` is a PIL image built
    from one entry of the file's ``'X'`` array and ``target`` is the matching
    entry of its ``'y'`` array, each optionally passed through a transform.

    Args:
        root: Directory holding (or receiving) the split's ``.mat`` file.
            A leading ``~`` is expanded to the user's home directory.
        split: Which file to use; must be a key of ``split_list``
            (``'train'``, ``'test'`` or ``'extra'``).
        transform: Optional callable applied to each sample's PIL image.
        target_transform: Optional callable applied to each sample's target.
        download: When true, ``download()`` is called before the integrity
            check.

    Raises:
        ValueError: If ``split`` is not a key of ``split_list``.
        RuntimeError: If the integrity check fails for the split's file.
    """

    # Placeholders; __init__ overwrites them per instance from split_list.
    url = ""
    filename = ""
    file_md5 = ""

    # split name -> [download url, local file name, expected md5]
    split_list = {
        'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
                  "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"],
        'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
                 "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"],
        'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
                  "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]}

    def __init__(self, root, split='train',
                 transform=None, target_transform=None, download=False):
        # Expand "~" so that e.g. root="~/data" does not create or look for
        # a directory literally named "~" under the current directory.
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.split = split  # training set or test set or extra set

        if self.split not in self.split_list:
            raise ValueError('Wrong split entered! Please use split="train" '
                             'or split="extra" or split="test"')

        self.url, self.filename, self.file_md5 = self.split_list[split]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # import here rather than at top of file because this is
        # an optional dependency for torchvision
        import scipy.io as sio

        # reading(loading) mat file as array
        loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))

        # NOTE(review): labels are kept exactly as stored under 'y'.
        # Presumably this is an (N, 1) column in which digit 0 is labelled
        # 10, so targets are length-1 arrays - confirm before feeding them
        # to a loss that expects scalar class indices in [0, C-1].
        self.labels = loaded_mat['y']

        # Move the last axis (the sample index) to the front and axis 2 next,
        # so self.data[i] is one image that __getitem__ can transpose back
        # with (1, 2, 0) for PIL.
        self.data = np.transpose(loaded_mat['X'], (3, 2, 0, 1))

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair stored at ``index``.

        The image is converted to a PIL image before ``transform`` is
        applied; ``target_transform`` is applied to the stored label entry.
        """
        img, target = self.data[index], self.labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(np.transpose(img, (1, 2, 0)))

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples (length of the first data axis)."""
        return len(self.data)

    def _check_integrity(self):
        """Return ``check_integrity``'s verdict for this split's file and md5."""
        fpath = os.path.join(self.root, self.filename)
        return check_integrity(fpath, self.file_md5)

    def download(self):
        """Hand this split's url, root, file name and md5 to ``download_url``."""
        download_url(self.url, self.root, self.filename, self.file_md5)