import os
from typing import Any, Callable, List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image

from .utils import download_url
from .vision import VisionDataset


class PhotoTour(VisionDataset):
    """`Multi-view Stereo Correspondence <http://matthewalunbrown.com/patchdata/patchdata.html>`_ Dataset.

    .. note::

        We only provide the newer version of the dataset, since the authors state that it

            is more suitable for training descriptors based on difference of Gaussian, or Harris corners, as the
            patches are centred on real interest point detections, rather than being projections of 3D points as is the
            case in the old dataset.

        The original dataset is available under http://phototour.cs.washington.edu/patches/default.htm.


    Args:
        root (string): Root directory where images are.
        name (string): Name of the dataset to load.
        train (bool, optional): If True, the dataset yields single patches; otherwise
            it yields patch pairs together with their ground-truth match label.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
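
    Example (a hypothetical usage sketch; the ``root`` path is a placeholder)::

        >>> dataset = PhotoTour(root="data/phototour", name="notredame",
        ...                     train=False, download=True)
        >>> patch1, patch2, is_match = dataset[0]  # test mode yields patch pairs
        >>> patches = PhotoTour(root="data/phototour", name="notredame", train=True)
        >>> single_patch = patches[0]  # train mode yields single patches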

    """

    urls = {
        "notredame_harris": [
            "http://matthewalunbrown.com/patchdata/notredame_harris.zip",
            "notredame_harris.zip",
            "69f8c90f78e171349abdf0307afefe4d",
        ],
        "yosemite_harris": [
            "http://matthewalunbrown.com/patchdata/yosemite_harris.zip",
            "yosemite_harris.zip",
            "a73253d1c6fbd3ba2613c45065c00d46",
        ],
        "liberty_harris": [
            "http://matthewalunbrown.com/patchdata/liberty_harris.zip",
            "liberty_harris.zip",
            "c731fcfb3abb4091110d0ae8c7ba182c",
        ],
        "notredame": [
            "http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip",
            "notredame.zip",
            "509eda8535847b8c0a90bbb210c83484",
        ],
        "yosemite": ["http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip", "yosemite.zip", "533b2e8eb7ede31be40abc317b2fd4f0"],
        "liberty": ["http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip", "liberty.zip", "fdd9152f138ea5ef2091746689176414"],
    }
    means = {
        "notredame": 0.4854,
        "yosemite": 0.4844,
        "liberty": 0.4437,
        "notredame_harris": 0.4854,
        "yosemite_harris": 0.4844,
        "liberty_harris": 0.4437,
    }
    stds = {
        "notredame": 0.1864,
        "yosemite": 0.1818,
        "liberty": 0.2019,
        "notredame_harris": 0.1864,
        "yosemite_harris": 0.1818,
        "liberty_harris": 0.2019,
    }
    lens = {
        "notredame": 468159,
        "yosemite": 633587,
        "liberty": 450092,
        "liberty_harris": 379587,
        "yosemite_harris": 450912,
        "notredame_harris": 325295,
    }
    image_ext = "bmp"
    info_file = "info.txt"
    matches_files = "m50_100000_100000_0.txt"
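    # The patches are stored as uint8 tensors; the ``means``/``stds`` above appear
    # to be on the [0, 1] scale. A minimal normalization sketch (an assumption, not
    # part of this class; pass it as the ``transform`` argument):
    #
    #   transform = lambda patch: (patch.float() / 255.0 - PhotoTour.means["liberty"]) / PhotoTour.stds["liberty"]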

    def __init__(
        self, root: str, name: str, train: bool = True, transform: Optional[Callable] = None, download: bool = False
    ) -> None:
        super().__init__(root, transform=transform)
        self.name = name
        self.data_dir = os.path.join(self.root, name)
        self.data_down = os.path.join(self.root, f"{name}.zip")
        self.data_file = os.path.join(self.root, f"{name}.pt")

        self.train = train
        self.mean = self.means[name]
        self.std = self.stds[name]

        if download:
            self.download()

        if not self._check_datafile_exists():
            self.cache()

        # load the serialized data
        self.data, self.labels, self.matches = torch.load(self.data_file)
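        # After caching: self.data is an (N, 64, 64) uint8 tensor of patches,
        # self.labels holds the 3D point ID of each patch, and self.matches is an
        # (M, 3) LongTensor of (patch_idx1, patch_idx2, is_match) rows.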

    def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[Any, Any, torch.Tensor]]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (data1, data2, matches) when ``train`` is False; a single patch
            tensor when ``train`` is True.
        """
        if self.train:
            data = self.data[index]
            if self.transform is not None:
                data = self.transform(data)
            return data
        m = self.matches[index]
        data1, data2 = self.data[m[0]], self.data[m[1]]
        if self.transform is not None:
            data1 = self.transform(data1)
            data2 = self.transform(data2)
        return data1, data2, m[2]

    def __len__(self) -> int:
        return len(self.data if self.train else self.matches)

    def _check_datafile_exists(self) -> bool:
        return os.path.exists(self.data_file)

    def _check_downloaded(self) -> bool:
        return os.path.exists(self.data_dir)

    def download(self) -> None:
        if self._check_datafile_exists():
            print(f"# Found cached data {self.data_file}")
            return

        if not self._check_downloaded():
            # download files
            url = self.urls[self.name][0]
            filename = self.urls[self.name][1]
            md5 = self.urls[self.name][2]
            fpath = os.path.join(self.root, filename)

            download_url(url, self.root, filename, md5)

            print(f"# Extracting data {self.data_down}\n")

            import zipfile

            with zipfile.ZipFile(fpath, "r") as z:
                z.extractall(self.data_dir)

            os.unlink(fpath)

    def cache(self) -> None:
        # process and save as torch files
        print(f"# Caching data {self.data_file}")

        dataset = (
            read_image_file(self.data_dir, self.image_ext, self.lens[self.name]),
            read_info_file(self.data_dir, self.info_file),
            read_matches_files(self.data_dir, self.matches_files),
        )

        with open(self.data_file, "wb") as f:
            torch.save(dataset, f)

    def extra_repr(self) -> str:
        split = "Train" if self.train else "Test"
        return f"Split: {split}"


def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor:
    """Return a Tensor containing the patches"""

    def PIL2array(_img: Image.Image) -> np.ndarray:
        """Convert PIL image type to numpy 2D array"""
        return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64)

    def find_files(_data_dir: str, _image_ext: str) -> List[str]:
        """Return a list with the file names of the images containing the patches"""
        files = []
        # find those files with the specified extension
        for file_dir in os.listdir(_data_dir):
            if file_dir.endswith(_image_ext):
                files.append(os.path.join(_data_dir, file_dir))
        return sorted(files)  # sort files in ascending order to keep relations

    patches = []
    list_files = find_files(data_dir, image_ext)

    for fpath in list_files:
        img = Image.open(fpath)
        for y in range(0, img.height, 64):
            for x in range(0, img.width, 64):
                patch = img.crop((x, y, x + 64, y + 64))
                patches.append(PIL2array(patch))
    return torch.ByteTensor(np.array(patches[:n]))


def read_info_file(data_dir: str, info_file: str) -> torch.Tensor:
    """Return a Tensor containing the list of labels
    Read the file and keep only the ID of the 3D point.
    """
    with open(os.path.join(data_dir, info_file)) as f:
        labels = [int(line.split()[0]) for line in f]
    return torch.LongTensor(labels)


def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor:
    """Return a Tensor containing the ground truth matches
    Read the file and keep, for each pair, the two patch indices and whether
    their 3D point IDs agree.
    Matches are represented with a 1, non-matches with a 0.
    """
    matches = []
    with open(os.path.join(data_dir, matches_file)) as f:
        for line in f:
            line_split = line.split()
            matches.append([int(line_split[0]), int(line_split[3]), int(line_split[1] == line_split[4])])
    return torch.LongTensor(matches)