"packaging/git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "a22b1e32a4912a518df11fe62912ff0247a2e779"
Unverified Commit 62e3fbd8 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add typehints for torchvision.datasets.phototour (#2531)

parent 1a6148d4
import os import os
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from typing import Any, Callable, List, Optional, Tuple, Union
import torch import torch
from .vision import VisionDataset from .vision import VisionDataset
...@@ -54,17 +55,19 @@ class PhotoTour(VisionDataset): ...@@ -54,17 +55,19 @@ class PhotoTour(VisionDataset):
'fdd9152f138ea5ef2091746689176414' 'fdd9152f138ea5ef2091746689176414'
], ],
} }
mean = {'notredame': 0.4854, 'yosemite': 0.4844, 'liberty': 0.4437, means = {'notredame': 0.4854, 'yosemite': 0.4844, 'liberty': 0.4437,
'notredame_harris': 0.4854, 'yosemite_harris': 0.4844, 'liberty_harris': 0.4437} 'notredame_harris': 0.4854, 'yosemite_harris': 0.4844, 'liberty_harris': 0.4437}
std = {'notredame': 0.1864, 'yosemite': 0.1818, 'liberty': 0.2019, stds = {'notredame': 0.1864, 'yosemite': 0.1818, 'liberty': 0.2019,
'notredame_harris': 0.1864, 'yosemite_harris': 0.1818, 'liberty_harris': 0.2019} 'notredame_harris': 0.1864, 'yosemite_harris': 0.1818, 'liberty_harris': 0.2019}
lens = {'notredame': 468159, 'yosemite': 633587, 'liberty': 450092, lens = {'notredame': 468159, 'yosemite': 633587, 'liberty': 450092,
'liberty_harris': 379587, 'yosemite_harris': 450912, 'notredame_harris': 325295} 'liberty_harris': 379587, 'yosemite_harris': 450912, 'notredame_harris': 325295}
image_ext = 'bmp' image_ext = 'bmp'
info_file = 'info.txt' info_file = 'info.txt'
matches_files = 'm50_100000_100000_0.txt' matches_files = 'm50_100000_100000_0.txt'
def __init__(self, root, name, train=True, transform=None, download=False): def __init__(
self, root: str, name: str, train: bool = True, transform: Optional[Callable] = None, download: bool = False
) -> None:
super(PhotoTour, self).__init__(root, transform=transform) super(PhotoTour, self).__init__(root, transform=transform)
self.name = name self.name = name
self.data_dir = os.path.join(self.root, name) self.data_dir = os.path.join(self.root, name)
...@@ -72,8 +75,8 @@ class PhotoTour(VisionDataset): ...@@ -72,8 +75,8 @@ class PhotoTour(VisionDataset):
self.data_file = os.path.join(self.root, '{}.pt'.format(name)) self.data_file = os.path.join(self.root, '{}.pt'.format(name))
self.train = train self.train = train
self.mean = self.mean[name] self.mean = self.means[name]
self.std = self.std[name] self.std = self.stds[name]
if download: if download:
self.download() self.download()
...@@ -85,7 +88,7 @@ class PhotoTour(VisionDataset): ...@@ -85,7 +88,7 @@ class PhotoTour(VisionDataset):
# load the serialized data # load the serialized data
self.data, self.labels, self.matches = torch.load(self.data_file) self.data, self.labels, self.matches = torch.load(self.data_file)
def __getitem__(self, index): def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[Any, Any, torch.Tensor]]:
""" """
Args: Args:
index (int): Index index (int): Index
...@@ -105,18 +108,18 @@ class PhotoTour(VisionDataset): ...@@ -105,18 +108,18 @@ class PhotoTour(VisionDataset):
data2 = self.transform(data2) data2 = self.transform(data2)
return data1, data2, m[2] return data1, data2, m[2]
def __len__(self) -> int:
    """Return the dataset size.

    For the training split this is the precomputed number of patches for
    the selected sequence (``lens[name]``); for the test split it is the
    number of ground-truth match pairs loaded from the matches file.
    """
    if self.train:
        return self.lens[self.name]
    return len(self.matches)
def _check_datafile_exists(self) -> bool:
    """Return ``True`` if the serialized dataset file (``<name>.pt``) is on disk.

    Used by ``download`` to skip re-processing when a cached file exists.
    """
    return os.path.exists(self.data_file)
def _check_downloaded(self) -> bool:
    """Return ``True`` if the raw image directory for this sequence exists.

    Presence of ``self.data_dir`` is taken as evidence the archive was
    already downloaded and extracted.
    """
    return os.path.exists(self.data_dir)
def download(self): def download(self) -> None:
if self._check_datafile_exists(): if self._check_datafile_exists():
print('# Found cached data {}'.format(self.data_file)) print('# Found cached data {}'.format(self.data_file))
return return
...@@ -150,20 +153,20 @@ class PhotoTour(VisionDataset): ...@@ -150,20 +153,20 @@ class PhotoTour(VisionDataset):
with open(self.data_file, 'wb') as f: with open(self.data_file, 'wb') as f:
torch.save(dataset, f) torch.save(dataset, f)
def extra_repr(self) -> str:
    """Return the extra ``repr`` line describing the active split.

    Reported as "Train" only when ``self.train`` is exactly ``True``,
    otherwise "Test" (mirrors the original ``is True`` check).
    """
    return "Split: {}".format("Train" if self.train is True else "Test")
def read_image_file(data_dir, image_ext, n): def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor:
"""Return a Tensor containing the patches """Return a Tensor containing the patches
""" """
def PIL2array(_img: "Image.Image") -> np.ndarray:
    """Convert a 64x64 PIL image (or any object with ``getdata()``) to a
    2-D ``uint8`` numpy array of shape ``(64, 64)``.

    NOTE(review): assumes the image holds exactly 4096 single-channel
    pixel values; ``reshape`` raises otherwise — TODO confirm upstream
    patches are always 64x64 grayscale.
    """
    # String annotation keeps the function importable even where PIL is
    # not loaded; identical at runtime in the real module, which imports
    # `Image` at the top of the file.
    return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64)
def find_files(_data_dir, _image_ext): def find_files(_data_dir: str, _image_ext: str) -> List[str]:
"""Return a list with the file names of the images containing the patches """Return a list with the file names of the images containing the patches
""" """
files = [] files = []
...@@ -185,7 +188,7 @@ def read_image_file(data_dir, image_ext, n): ...@@ -185,7 +188,7 @@ def read_image_file(data_dir, image_ext, n):
return torch.ByteTensor(np.array(patches[:n])) return torch.ByteTensor(np.array(patches[:n]))
def read_info_file(data_dir, info_file): def read_info_file(data_dir: str, info_file: str) -> torch.Tensor:
"""Return a Tensor containing the list of labels """Return a Tensor containing the list of labels
Read the file and keep only the ID of the 3D point. Read the file and keep only the ID of the 3D point.
""" """
...@@ -195,7 +198,7 @@ def read_info_file(data_dir, info_file): ...@@ -195,7 +198,7 @@ def read_info_file(data_dir, info_file):
return torch.LongTensor(labels) return torch.LongTensor(labels)
def read_matches_files(data_dir, matches_file): def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor:
"""Return a Tensor containing the ground truth matches """Return a Tensor containing the ground truth matches
Read the file and keep only 3D point ID. Read the file and keep only 3D point ID.
Matches are represented with a 1, non matches with a 0. Matches are represented with a 1, non matches with a 0.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment