import os
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image

from .utils import download_url
from .vision import VisionDataset


class PhotoTour(VisionDataset):
    """`Multi-view Stereo Correspondence <http://matthewalunbrown.com/patchdata/patchdata.html>`_ Dataset.

    .. note::

        We only provide the newer version of the dataset, since the authors state that it

            is more suitable for training descriptors based on difference of Gaussian, or Harris corners, as the
            patches are centred on real interest point detections, rather than being projections of 3D points as is the
            case in the old dataset.

        The original dataset is available under http://phototour.cs.washington.edu/patches/default.htm.


    Args:
        root (str or ``pathlib.Path``): Root directory where images are.
        name (string): Name of the dataset to load.
        train (bool, optional): If True, ``__getitem__`` returns single patches;
            otherwise it returns ``(patch_1, patch_2, match)`` triplets built from
            the ground-truth match file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """

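    # Example usage (illustrative; the root path is a placeholder and any of the six
    # dataset names listed in ``urls`` below may be used):
    #
    #   train_set = PhotoTour(root="./data", name="notredame", train=True, download=True)
    #   patch = train_set[0]            # a single 64x64 uint8 patch tensor
    #
    #   test_set = PhotoTour(root="./data", name="notredame", train=False)
    #   p1, p2, is_match = test_set[0]  # a ground-truth (patch, patch, 0/1) triplet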
    urls = {
        "notredame_harris": [
            "http://matthewalunbrown.com/patchdata/notredame_harris.zip",
            "notredame_harris.zip",
            "69f8c90f78e171349abdf0307afefe4d",
        ],
        "yosemite_harris": [
            "http://matthewalunbrown.com/patchdata/yosemite_harris.zip",
            "yosemite_harris.zip",
            "a73253d1c6fbd3ba2613c45065c00d46",
        ],
        "liberty_harris": [
            "http://matthewalunbrown.com/patchdata/liberty_harris.zip",
            "liberty_harris.zip",
            "c731fcfb3abb4091110d0ae8c7ba182c",
        ],
        "notredame": [
            "http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip",
            "notredame.zip",
            "509eda8535847b8c0a90bbb210c83484",
        ],
        "yosemite": ["http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip", "yosemite.zip", "533b2e8eb7ede31be40abc317b2fd4f0"],
        "liberty": ["http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip", "liberty.zip", "fdd9152f138ea5ef2091746689176414"],
    }
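    # Per-dataset grayscale statistics (mean/std on a 0-1 scale) and the total number
    # of patches each dataset contains.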
    means = {
        "notredame": 0.4854,
        "yosemite": 0.4844,
        "liberty": 0.4437,
        "notredame_harris": 0.4854,
        "yosemite_harris": 0.4844,
        "liberty_harris": 0.4437,
    }
    stds = {
        "notredame": 0.1864,
        "yosemite": 0.1818,
        "liberty": 0.2019,
        "notredame_harris": 0.1864,
        "yosemite_harris": 0.1818,
        "liberty_harris": 0.2019,
    }
    lens = {
        "notredame": 468159,
        "yosemite": 633587,
        "liberty": 450092,
        "liberty_harris": 379587,
        "yosemite_harris": 450912,
        "notredame_harris": 325295,
    }
    image_ext = "bmp"
    info_file = "info.txt"
    matches_files = "m50_100000_100000_0.txt"

    def __init__(
        self,
        root: Union[str, Path],
        name: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super().__init__(root, transform=transform)
        self.name = name
        self.data_dir = os.path.join(self.root, name)
        self.data_down = os.path.join(self.root, f"{name}.zip")
        self.data_file = os.path.join(self.root, f"{name}.pt")

        self.train = train
        self.mean = self.means[name]
        self.std = self.stds[name]

        if download:
            self.download()

        if not self._check_datafile_exists():
            self.cache()

        # load the serialized data
        self.data, self.labels, self.matches = torch.load(self.data_file, weights_only=True)

    def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[Any, Any, torch.Tensor]]:
        """
        Args:
            index (int): Index

        Returns:
            Union[torch.Tensor, Tuple[Any, Any, torch.Tensor]]: a single (optionally
            transformed) patch if ``train`` is True, otherwise the
            ``(data1, data2, matches)`` triplet for the indexed ground-truth pair.
        """
        if self.train:
            data = self.data[index]
            if self.transform is not None:
                data = self.transform(data)
            return data
        m = self.matches[index]
        data1, data2 = self.data[m[0]], self.data[m[1]]
        if self.transform is not None:
            data1 = self.transform(data1)
            data2 = self.transform(data2)
        return data1, data2, m[2]

    def __len__(self) -> int:
        return len(self.data if self.train else self.matches)

    def _check_datafile_exists(self) -> bool:
        return os.path.exists(self.data_file)

    def _check_downloaded(self) -> bool:
        return os.path.exists(self.data_dir)

    def download(self) -> None:
        if self._check_datafile_exists():
            print(f"# Found cached data {self.data_file}")
            return

        if not self._check_downloaded():
            # download files
            url = self.urls[self.name][0]
            filename = self.urls[self.name][1]
            md5 = self.urls[self.name][2]
            fpath = os.path.join(self.root, filename)

            download_url(url, self.root, filename, md5)

            print(f"# Extracting data {self.data_down}\n")

            import zipfile

            with zipfile.ZipFile(fpath, "r") as z:
                z.extractall(self.data_dir)

            os.unlink(fpath)

    def cache(self) -> None:
        # process and save as torch files
        print(f"# Caching data {self.data_file}")

        dataset = (
            read_image_file(self.data_dir, self.image_ext, self.lens[self.name]),
            read_info_file(self.data_dir, self.info_file),
            read_matches_files(self.data_dir, self.matches_files),
        )

        with open(self.data_file, "wb") as f:
            torch.save(dataset, f)

    def extra_repr(self) -> str:
        split = "Train" if self.train is True else "Test"
        return f"Split: {split}"


def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor:
    """Return a Tensor containing the patches"""

    def PIL2array(_img: Image.Image) -> np.ndarray:
        """Convert PIL image type to numpy 2D array"""
        return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64)

    def find_files(_data_dir: str, _image_ext: str) -> List[str]:
        """Return a list with the file names of the images containing the patches"""
        files = []
        # find those files with the specified extension
        for file_dir in os.listdir(_data_dir):
            if file_dir.endswith(_image_ext):
                files.append(os.path.join(_data_dir, file_dir))
        return sorted(files)  # sort files in ascend order to keep relations

    patches = []
    list_files = find_files(data_dir, image_ext)

    for fpath in list_files:
        img = Image.open(fpath)
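        # Each .bmp file is a mosaic of patches laid out on a 64x64 grid; walk it
        # row by row and crop out the individual patches.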
        for y in range(0, img.height, 64):
            for x in range(0, img.width, 64):
                patch = img.crop((x, y, x + 64, y + 64))
                patches.append(PIL2array(patch))
    return torch.ByteTensor(np.array(patches[:n]))


def read_info_file(data_dir: str, info_file: str) -> torch.Tensor:
    """Return a Tensor containing the list of labels
    Read the file and keep only the ID of the 3D point.
    """
    with open(os.path.join(data_dir, info_file)) as f:
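        # info.txt has one row per patch; the first column is the ID of the 3D point
        # the patch was sampled from, which serves as its label.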
        labels = [int(line.split()[0]) for line in f]
    return torch.LongTensor(labels)


def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor:
    """Return a Tensor containing the ground truth matches
    Read the file and keep only 3D point ID.
    Matches are represented with a 1, non matches with a 0.
    """
    matches = []
    with open(os.path.join(data_dir, matches_file)) as f:
        for line in f:
            line_split = line.split()
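            # Columns 0 and 3 hold the two patch indices, columns 1 and 4 their 3D point
            # IDs; a pair counts as a match when both patches come from the same 3D point.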
            matches.append([int(line_split[0]), int(line_split[3]), int(line_split[1] == line_split[4])])
    return torch.LongTensor(matches)