phototour.py 8 KB
Newer Older
edgarriba's avatar
edgarriba committed
1
2
3
import os
import numpy as np
from PIL import Image
4
from typing import Any, Callable, List, Optional, Tuple, Union
edgarriba's avatar
edgarriba committed
5
6

import torch
7
from .vision import VisionDataset
edgarriba's avatar
edgarriba committed
8

9
from .utils import download_url
soumith's avatar
soumith committed
10

edgarriba's avatar
edgarriba committed
11

12
class PhotoTour(VisionDataset):
    """`Multi-view Stereo Correspondence <http://matthewalunbrown.com/patchdata/patchdata.html>`_ Dataset.

    .. note::

        We only provide the newer version of the dataset, since the authors state that it

            is more suitable for training descriptors based on difference of Gaussian, or Harris corners, as the
            patches are centred on real interest point detections, rather than being projections of 3D points as is the
            case in the old dataset.

        The original dataset is available under http://phototour.cs.washington.edu/patches/default.htm.


    Args:
        root (string): Root directory where images are.
        name (string): Name of the dataset to load.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """
    # Per-variant [download url, archive filename, md5 checksum of the archive].
    urls = {
        'notredame_harris': [
            'http://matthewalunbrown.com/patchdata/notredame_harris.zip',
            'notredame_harris.zip',
            '69f8c90f78e171349abdf0307afefe4d'
        ],
        'yosemite_harris': [
            'http://matthewalunbrown.com/patchdata/yosemite_harris.zip',
            'yosemite_harris.zip',
            'a73253d1c6fbd3ba2613c45065c00d46'
        ],
        'liberty_harris': [
            'http://matthewalunbrown.com/patchdata/liberty_harris.zip',
            'liberty_harris.zip',
            'c731fcfb3abb4091110d0ae8c7ba182c'
        ],
        'notredame': [
            'http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip',
            'notredame.zip',
            '509eda8535847b8c0a90bbb210c83484'
        ],
        'yosemite': [
            'http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip',
            'yosemite.zip',
            '533b2e8eb7ede31be40abc317b2fd4f0'
        ],
        'liberty': [
            'http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip',
            'liberty.zip',
            'fdd9152f138ea5ef2091746689176414'
        ],
    }
    # Precomputed patch-intensity statistics per variant (used by callers for
    # normalization; this class only exposes them via self.mean / self.std).
    means = {'notredame': 0.4854, 'yosemite': 0.4844, 'liberty': 0.4437,
             'notredame_harris': 0.4854, 'yosemite_harris': 0.4844, 'liberty_harris': 0.4437}
    stds = {'notredame': 0.1864, 'yosemite': 0.1818, 'liberty': 0.2019,
            'notredame_harris': 0.1864, 'yosemite_harris': 0.1818, 'liberty_harris': 0.2019}
    # Total number of patches contained in each variant.
    lens = {'notredame': 468159, 'yosemite': 633587, 'liberty': 450092,
            'liberty_harris': 379587, 'yosemite_harris': 450912, 'notredame_harris': 325295}

    image_ext = 'bmp'                            # extension of the patch-mosaic images
    info_file = 'info.txt'                       # per-patch 3D-point IDs
    matches_files = 'm50_100000_100000_0.txt'    # ground-truth match pairs

    def __init__(
            self, root: str, name: str, train: bool = True, transform: Optional[Callable] = None, download: bool = False
    ) -> None:
        super().__init__(root, transform=transform)
        self.name = name
        self.data_dir = os.path.join(self.root, name)
        self.data_down = os.path.join(self.root, '{}.zip'.format(name))
        self.data_file = os.path.join(self.root, '{}.pt'.format(name))

        self.train = train
        self.mean = self.means[name]
        self.std = self.stds[name]

        if download:
            self.download()

        if not self._check_datafile_exists():
            # Fall back to building the cache from an already-extracted copy;
            # chain the original failure so the user sees the real cause.
            try:
                self.cache()
            except Exception as error:
                raise RuntimeError("Dataset not found. You can use download=True to download it") from error

        # load the serialized data
        self.data, self.labels, self.matches = torch.load(self.data_file)

    def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[Any, Any, torch.Tensor]]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (data1, data2, matches)
        """
        # Train split yields a single patch; test split yields a labelled pair.
        if self.train:
            data = self.data[index]
            if self.transform is not None:
                data = self.transform(data)
            return data
        m = self.matches[index]
        data1, data2 = self.data[m[0]], self.data[m[1]]
        if self.transform is not None:
            data1 = self.transform(data1)
            data2 = self.transform(data2)
        return data1, data2, m[2]

    def __len__(self) -> int:
        """Return the number of patches (train) or match pairs (test)."""
        if self.train:
            return self.lens[self.name]
        return len(self.matches)

    def _check_datafile_exists(self) -> bool:
        """Return True if the cached ``<name>.pt`` file is already on disk."""
        return os.path.exists(self.data_file)

    def _check_downloaded(self) -> bool:
        """Return True if the extracted dataset directory exists."""
        return os.path.exists(self.data_dir)

    def download(self) -> None:
        """Download and extract the archive, unless already cached/extracted."""
        if self._check_datafile_exists():
            print('# Found cached data {}'.format(self.data_file))
            return

        if not self._check_downloaded():
            # download files
            url = self.urls[self.name][0]
            filename = self.urls[self.name][1]
            md5 = self.urls[self.name][2]
            fpath = os.path.join(self.root, filename)

            download_url(url, self.root, filename, md5)

            print('# Extracting data {}\n'.format(self.data_down))

            # Imported lazily: only needed on the (rare) download path.
            import zipfile
            with zipfile.ZipFile(fpath, 'r') as z:
                z.extractall(self.data_dir)

            os.unlink(fpath)

    def cache(self) -> None:
        # process and save as torch files
        print('# Caching data {}'.format(self.data_file))

        dataset = (
            read_image_file(self.data_dir, self.image_ext, self.lens[self.name]),
            read_info_file(self.data_dir, self.info_file),
            read_matches_files(self.data_dir, self.matches_files)
        )

        with open(self.data_file, 'wb') as f:
            torch.save(dataset, f)

    def extra_repr(self) -> str:
        """Return the split name shown in ``repr()``."""
        return "Split: {}".format("Train" if self.train else "Test")
edgarriba's avatar
edgarriba committed
172

173
def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor:
    """Return a ByteTensor with the first ``n`` 64x64 patches found in the
    ``image_ext`` mosaic images inside ``data_dir``.

    Each mosaic is scanned row by row in 64-pixel steps (a 1024x1024 image
    therefore yields 256 patches); files are visited in sorted order so the
    patch index stays aligned with the info/matches files.
    """

    def PIL2array(_img: Image.Image) -> np.ndarray:
        """Convert PIL image type to numpy 2D array
        """
        # assumes a single-channel 64x64 patch — getdata() is flat, reshape restores 2D
        return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64)

    def find_files(_data_dir: str, _image_ext: str) -> List[str]:
        """Return a list with the file names of the images containing the patches
        """
        files = []
        # find those files with the specified extension
        for file_dir in os.listdir(_data_dir):
            if file_dir.endswith(_image_ext):
                files.append(os.path.join(_data_dir, file_dir))
        return sorted(files)  # sort files in ascend order to keep relations

    patches = []
    list_files = find_files(data_dir, image_ext)

    for fpath in list_files:
        # Fix: close each image file promptly via a context manager instead of
        # leaking the handle until garbage collection (many files are opened).
        with Image.open(fpath) as img:
            for y in range(0, 1024, 64):
                for x in range(0, 1024, 64):
                    patch = img.crop((x, y, x + 64, y + 64))
                    patches.append(PIL2array(patch))
    return torch.ByteTensor(np.array(patches[:n]))


204
def read_info_file(data_dir: str, info_file: str) -> torch.Tensor:
    """Return a LongTensor with the 3D-point ID of every patch.

    Only the first whitespace-separated column of each line is kept;
    it identifies the 3D point the patch was sampled from.
    """
    info_path = os.path.join(data_dir, info_file)
    point_ids = []
    with open(info_path, 'r') as handle:
        for row in handle:
            point_ids.append(int(row.split()[0]))
    return torch.LongTensor(point_ids)


213
def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor:
    """Return a LongTensor of ground-truth match triples.

    Each output row is ``[patch_id_1, patch_id_2, is_match]`` where
    ``is_match`` is 1 when both patches belong to the same 3D point
    (columns 1 and 4 of the input line agree) and 0 otherwise.
    """
    matches_path = os.path.join(data_dir, matches_file)
    triples = []
    with open(matches_path, 'r') as handle:
        for row in handle:
            fields = row.split()
            same_point = int(fields[1] == fields[4])
            triples.append([int(fields[0]), int(fields[3]), same_point])
    return torch.LongTensor(triples)