phototour.py 7.85 KB
Newer Older
edgarriba's avatar
edgarriba committed
1
import os
limm's avatar
limm committed
2
from pathlib import Path
3
from typing import Any, Callable, List, Optional, Tuple, Union
edgarriba's avatar
edgarriba committed
4

limm's avatar
limm committed
5
import numpy as np
edgarriba's avatar
edgarriba committed
6
import torch
limm's avatar
limm committed
7
from PIL import Image
edgarriba's avatar
edgarriba committed
8

9
from .utils import download_url
limm's avatar
limm committed
10
from .vision import VisionDataset
soumith's avatar
soumith committed
11

edgarriba's avatar
edgarriba committed
12

13
class PhotoTour(VisionDataset):
    """`Multi-view Stereo Correspondence <http://matthewalunbrown.com/patchdata/patchdata.html>`_ Dataset.

    .. note::

        We only provide the newer version of the dataset, since the authors state that it

            is more suitable for training descriptors based on difference of Gaussian, or Harris corners, as the
            patches are centred on real interest point detections, rather than being projections of 3D points as is the
            case in the old dataset.

        The original dataset is available under http://phototour.cs.washington.edu/patches/default.htm.


    Args:
        root (str or ``pathlib.Path``): Root directory where images are.
        name (string): Name of the dataset to load.
        train (bool, optional): If True, ``__getitem__`` yields single patches;
            otherwise it yields ``(patch1, patch2, is_match)`` triplets built
            from the ground-truth match file. Default: True.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """

    # Per-dataset [url, archive filename, md5 checksum] used by download().
    urls = {
        "notredame_harris": [
            "http://matthewalunbrown.com/patchdata/notredame_harris.zip",
            "notredame_harris.zip",
            "69f8c90f78e171349abdf0307afefe4d",
        ],
        "yosemite_harris": [
            "http://matthewalunbrown.com/patchdata/yosemite_harris.zip",
            "yosemite_harris.zip",
            "a73253d1c6fbd3ba2613c45065c00d46",
        ],
        "liberty_harris": [
            "http://matthewalunbrown.com/patchdata/liberty_harris.zip",
            "liberty_harris.zip",
            "c731fcfb3abb4091110d0ae8c7ba182c",
        ],
        "notredame": [
            "http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip",
            "notredame.zip",
            "509eda8535847b8c0a90bbb210c83484",
        ],
        "yosemite": ["http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip", "yosemite.zip", "533b2e8eb7ede31be40abc317b2fd4f0"],
        "liberty": ["http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip", "liberty.zip", "fdd9152f138ea5ef2091746689176414"],
    }
    # Per-dataset patch intensity mean, exposed for caller-side normalization.
    means = {
        "notredame": 0.4854,
        "yosemite": 0.4844,
        "liberty": 0.4437,
        "notredame_harris": 0.4854,
        "yosemite_harris": 0.4844,
        "liberty_harris": 0.4437,
    }
    # Per-dataset patch intensity standard deviation.
    stds = {
        "notredame": 0.1864,
        "yosemite": 0.1818,
        "liberty": 0.2019,
        "notredame_harris": 0.1864,
        "yosemite_harris": 0.1818,
        "liberty_harris": 0.2019,
    }
    # Number of valid patches per dataset; used to drop padding cells when
    # slicing the patch mosaics in read_image_file.
    lens = {
        "notredame": 468159,
        "yosemite": 633587,
        "liberty": 450092,
        "liberty_harris": 379587,
        "yosemite_harris": 450912,
        "notredame_harris": 325295,
    }
    image_ext = "bmp"
    info_file = "info.txt"
    matches_files = "m50_100000_100000_0.txt"

    def __init__(
        self,
        root: Union[str, Path],
        name: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform)
        self.name = name
        self.data_dir = os.path.join(self.root, name)
        self.data_down = os.path.join(self.root, f"{name}.zip")
        self.data_file = os.path.join(self.root, f"{name}.pt")

        self.train = train
        self.mean = self.means[name]
        self.std = self.stds[name]

        if download:
            self.download()

        # Serialize the raw image/info/match files into one .pt file the
        # first time this dataset is used; subsequent runs hit the cache.
        if not self._check_datafile_exists():
            self.cache()

        # load the serialized data
        self.data, self.labels, self.matches = torch.load(self.data_file, weights_only=True)

    def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[Any, Any, torch.Tensor]]:
        """
        Args:
            index (int): Index

        Returns:
            torch.Tensor: the ``index``-th patch when ``self.train`` is True.
            tuple: ``(data1, data2, matches)`` otherwise, where ``matches``
                is 1 for a ground-truth match and 0 for a non-match.
        """
        if self.train:
            data = self.data[index]
            if self.transform is not None:
                data = self.transform(data)
            return data
        # Evaluation mode: each entry of self.matches is (idx1, idx2, label).
        m = self.matches[index]
        data1, data2 = self.data[m[0]], self.data[m[1]]
        if self.transform is not None:
            data1 = self.transform(data1)
            data2 = self.transform(data2)
        return data1, data2, m[2]

    def __len__(self) -> int:
        # Training iterates over individual patches; evaluation over pairs.
        return len(self.data if self.train else self.matches)

    def _check_datafile_exists(self) -> bool:
        # True when the cached .pt serialization already exists.
        return os.path.exists(self.data_file)

    def _check_downloaded(self) -> bool:
        # True when the archive has already been extracted into data_dir.
        return os.path.exists(self.data_dir)

    def download(self) -> None:
        """Download, verify, and extract the archive unless already cached."""
        if self._check_datafile_exists():
            print(f"# Found cached data {self.data_file}")
            return

        if not self._check_downloaded():
            # download files
            url = self.urls[self.name][0]
            filename = self.urls[self.name][1]
            md5 = self.urls[self.name][2]
            fpath = os.path.join(self.root, filename)

            download_url(url, self.root, filename, md5)

            print(f"# Extracting data {self.data_down}\n")

            import zipfile

            with zipfile.ZipFile(fpath, "r") as z:
                z.extractall(self.data_dir)

            # The archive is no longer needed once extracted.
            os.unlink(fpath)

    def cache(self) -> None:
        # process and save as torch files
        print(f"# Caching data {self.data_file}")

        dataset = (
            read_image_file(self.data_dir, self.image_ext, self.lens[self.name]),
            read_info_file(self.data_dir, self.info_file),
            read_matches_files(self.data_dir, self.matches_files),
        )

        with open(self.data_file, "wb") as f:
            torch.save(dataset, f)

    def extra_repr(self) -> str:
        # Plain truthiness instead of `is True` (PEP 8: never compare to
        # True/False with `is` or `==`).
        split = "Train" if self.train else "Test"
        return f"Split: {split}"
185

edgarriba's avatar
edgarriba committed
186

187
def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor:
    """Return a Tensor containing the first ``n`` 64x64 patches.

    Each file in ``data_dir`` with extension ``image_ext`` is a mosaic of
    64x64 patches on a regular grid (assumed single-band, since each patch
    is reshaped to a flat 64x64 array). Mosaics are read in sorted filename
    order so patch indices line up with the info/matches files.

    Args:
        data_dir (str): directory containing the patch mosaic images.
        image_ext (str): filename extension of the mosaics (e.g. "bmp").
        n (int): total number of valid patches; grid cells beyond this
            count are padding and are discarded.

    Returns:
        torch.Tensor: uint8 tensor of shape (n, 64, 64).
    """

    def PIL2array(_img: Image.Image) -> np.ndarray:
        """Convert PIL image type to numpy 2D array"""
        return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64)

    def find_files(_data_dir: str, _image_ext: str) -> List[str]:
        """Return a list with the file names of the images containing the patches"""
        files = [
            os.path.join(_data_dir, file_dir)
            for file_dir in os.listdir(_data_dir)
            if file_dir.endswith(_image_ext)
        ]
        return sorted(files)  # sort files in ascend order to keep relations

    patches = []
    for fpath in find_files(data_dir, image_ext):
        # Context manager closes the image's file handle promptly; the
        # original `Image.open(fpath)` leaked it until garbage collection.
        with Image.open(fpath) as img:
            for y in range(0, img.height, 64):
                for x in range(0, img.width, 64):
                    patch = img.crop((x, y, x + 64, y + 64))
                    patches.append(PIL2array(patch))
    return torch.ByteTensor(np.array(patches[:n]))


215
def read_info_file(data_dir: str, info_file: str) -> torch.Tensor:
    """Return a Tensor with the 3D-point ID of every patch.

    Each line of the info file describes one patch; only the first column
    (the ID of the 3D point the patch was sampled from) is kept.
    """
    info_path = os.path.join(data_dir, info_file)
    labels = []
    with open(info_path) as handle:
        for row in handle:
            labels.append(int(row.split()[0]))
    return torch.LongTensor(labels)


224
def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor:
    """Return a Tensor with the ground-truth match triplets.

    Each row of the result is ``[patch_id1, patch_id2, label]`` where the
    label is 1 when both patches share the same 3D point ID (columns 1 and
    4 of the file) and 0 otherwise.
    """
    matches_path = os.path.join(data_dir, matches_file)
    with open(matches_path) as handle:
        matches = [
            [int(cols[0]), int(cols[3]), int(cols[1] == cols[4])]
            for cols in (row.split() for row in handle)
        ]
    return torch.LongTensor(matches)