Commit b13bac51 authored by soumith's avatar soumith Committed by Soumith Chintala
Browse files

refactored phototour to use utils

parent 32460f52
...@@ -6,5 +6,5 @@ max-line-length = 120 ...@@ -6,5 +6,5 @@ max-line-length = 120
[flake8] [flake8]
max-line-length = 120 max-line-length = 120
ignore = F401,F403 ignore = F401,E402,F403
exclude = venv exclude = venv
...@@ -5,6 +5,7 @@ from .cifar import CIFAR10, CIFAR100 ...@@ -5,6 +5,7 @@ from .cifar import CIFAR10, CIFAR100
from .stl10 import STL10 from .stl10 import STL10
from .mnist import MNIST from .mnist import MNIST
from .svhn import SVHN from .svhn import SVHN
from .phototour import PhotoTour
__all__ = ('LSUN', 'LSUNClass', __all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'ImageFolder',
......
from __future__ import print_function from __future__ import print_function
import torch.utils.data as data
from PIL import Image from PIL import Image
import os import os
import os.path import os.path
...@@ -11,6 +10,7 @@ if sys.version_info[0] == 2: ...@@ -11,6 +10,7 @@ if sys.version_info[0] == 2:
else: else:
import pickle import pickle
import torch.utils.data as data
from .utils import download_url, check_integrity from .utils import download_url, check_integrity
......
...@@ -127,15 +127,3 @@ class LSUN(data.Dataset): ...@@ -127,15 +127,3 @@ class LSUN(data.Dataset):
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + ' (' + self.db_path + ')' return self.__class__.__name__ + ' (' + self.db_path + ')'
if __name__ == '__main__':
# lsun = LSUNClass(db_path='/home/soumith/local/lsun/train/bedroom_train_lmdb')
# a = lsun[0]
lsun = LSUN(db_path='/home/soumith/local/lsun/train',
classes=['bedroom_train', 'church_outdoor_train'])
print(lsun.classes)
print(lsun.dbs)
a, t = lsun[len(lsun) - 1]
print(a)
print(t)
...@@ -6,12 +6,26 @@ from PIL import Image ...@@ -6,12 +6,26 @@ from PIL import Image
import torch import torch
import torch.utils.data as data import torch.utils.data as data
from .utils import download_url, check_integrity
class PhotoTour(data.Dataset): class PhotoTour(data.Dataset):
urls = { urls = {
'notredame': 'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/notredame.zip', 'notredame': [
'yosemite': 'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/yosemite.zip', 'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/notredame.zip',
'liberty': 'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/liberty.zip' 'notredame.zip',
'509eda8535847b8c0a90bbb210c83484'
],
'yosemite': [
'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/yosemite.zip',
'yosemite.zip',
'533b2e8eb7ede31be40abc317b2fd4f0'
],
'liberty': [
'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/liberty.zip',
'liberty.zip',
'fdd9152f138ea5ef2091746689176414'
],
} }
mean = {'notredame': 0.4854, 'yosemite': 0.4844, 'liberty': 0.4437} mean = {'notredame': 0.4854, 'yosemite': 0.4844, 'liberty': 0.4437}
std = {'notredame': 0.1864, 'yosemite': 0.1818, 'liberty': 0.2019} std = {'notredame': 0.1864, 'yosemite': 0.1818, 'liberty': 0.2019}
...@@ -37,7 +51,7 @@ class PhotoTour(data.Dataset): ...@@ -37,7 +51,7 @@ class PhotoTour(data.Dataset):
if download: if download:
self.download() self.download()
if not self._check_exists(): if not self._check_datafile_exists():
raise RuntimeError('Dataset not found.' + raise RuntimeError('Dataset not found.' +
' You can use download=True to download it') ' You can use download=True to download it')
...@@ -62,59 +76,45 @@ class PhotoTour(data.Dataset): ...@@ -62,59 +76,45 @@ class PhotoTour(data.Dataset):
return self.lens[self.name] return self.lens[self.name]
return len(self.matches) return len(self.matches)
def _check_exists(self): def _check_datafile_exists(self):
return os.path.exists(self.data_file) return os.path.exists(self.data_file)
def _check_downloaded(self): def _check_downloaded(self):
return os.path.exists(self.data_dir) return os.path.exists(self.data_dir)
def download(self): def download(self):
from six.moves import urllib if self._check_datafile_exists():
print('\n-- Loading PhotoTour dataset: {}\n'.format(self.name))
if self._check_exists():
print('# Found cached data {}'.format(self.data_file)) print('# Found cached data {}'.format(self.data_file))
return return
if not self._check_downloaded(): if not self._check_downloaded():
# download files # download files
url = self.urls[self.name] url = self.urls[self.name][0]
filename = url.rpartition('/')[2] filename = self.urls[self.name][1]
file_path = os.path.join(self.root, filename) md5 = self.urls[self.name][2]
fpath = os.path.join(self.root, filename)
try:
os.makedirs(self.root)
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise
print('# Downloading {} into {}\n\nIt might take while.' download_url(url, self.root, filename, md5)
' Please grab yourself a coffee and relax.'
.format(url, file_path))
urllib.request.urlretrieve(url, file_path)
assert os.path.exists(file_path)
print('# Extracting data {}\n'.format(self.data_down)) print('# Extracting data {}\n'.format(self.data_down))
import zipfile import zipfile
with zipfile.ZipFile(file_path, 'r') as z: with zipfile.ZipFile(fpath, 'r') as z:
z.extractall(self.data_dir) z.extractall(self.data_dir)
os.unlink(file_path)
os.unlink(fpath)
# process and save as torch files # process and save as torch files
print('# Caching data {}'.format(self.data_file)) print('# Caching data {}'.format(self.data_file))
data_set = ( dataset = (
read_image_file(self.data_dir, self.image_ext, self.lens[self.name]), read_image_file(self.data_dir, self.image_ext, self.lens[self.name]),
read_info_file(self.data_dir, self.info_file), read_info_file(self.data_dir, self.info_file),
read_matches_files(self.data_dir, self.matches_files) read_matches_files(self.data_dir, self.matches_files)
) )
with open(self.data_file, 'wb') as f: with open(self.data_file, 'wb') as f:
torch.save(data_set, f) torch.save(dataset, f)
def read_image_file(data_dir, image_ext, n): def read_image_file(data_dir, image_ext, n):
...@@ -138,8 +138,8 @@ def read_image_file(data_dir, image_ext, n): ...@@ -138,8 +138,8 @@ def read_image_file(data_dir, image_ext, n):
patches = [] patches = []
list_files = find_files(data_dir, image_ext) list_files = find_files(data_dir, image_ext)
for file_path in list_files: for fpath in list_files:
img = Image.open(file_path) img = Image.open(fpath)
for y in range(0, 1024, 64): for y in range(0, 1024, 64):
for x in range(0, 1024, 64): for x in range(0, 1024, 64):
patch = img.crop((x, y, x + 64, y + 64)) patch = img.crop((x, y, x + 64, y + 64))
...@@ -168,14 +168,3 @@ def read_matches_files(data_dir, matches_file): ...@@ -168,14 +168,3 @@ def read_matches_files(data_dir, matches_file):
l = line.split() l = line.split()
matches.append([int(l[0]), int(l[3]), int(l[1] == l[4])]) matches.append([int(l[0]), int(l[3]), int(l[1] == l[4])])
return torch.LongTensor(matches) return torch.LongTensor(matches)
if __name__ == '__main__':
dataset = PhotoTour(root='/home/eriba/datasets/patches_dataset',
name='notredame',
download=True)
print('Loaded PhotoTour: {} with {} images.'
.format(dataset.name, len(dataset.data)))
assert len(dataset.data) == len(dataset.labels)
...@@ -26,7 +26,8 @@ class STL10(CIFAR10): ...@@ -26,7 +26,8 @@ class STL10(CIFAR10):
['test_y.bin', '36f9794fa4beb8a2c72628de14fa638e'] ['test_y.bin', '36f9794fa4beb8a2c72628de14fa638e']
] ]
def __init__(self, root, split='train', transform=None, target_transform=None, download=False): def __init__(self, root, split='train',
transform=None, target_transform=None, download=False):
self.root = root self.root = root
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
...@@ -37,7 +38,8 @@ class STL10(CIFAR10): ...@@ -37,7 +38,8 @@ class STL10(CIFAR10):
if not self._check_integrity(): if not self._check_integrity():
raise RuntimeError( raise RuntimeError(
'Dataset not found or corrupted. You can use download=True to download it') 'Dataset not found or corrupted. '
'You can use download=True to download it')
# now load the picked numpy arrays # now load the picked numpy arrays
if self.split == 'train': if self.split == 'train':
......
...@@ -8,6 +8,7 @@ import numpy as np ...@@ -8,6 +8,7 @@ import numpy as np
import sys import sys
from .utils import download_url, check_integrity from .utils import download_url, check_integrity
class SVHN(data.Dataset): class SVHN(data.Dataset):
url = "" url = ""
filename = "" filename = ""
......
import os import os
import os.path
import hashlib import hashlib
import errno import errno
def check_integrity(fpath, md5): def check_integrity(fpath, md5):
if not os.path.isfile(fpath): if not os.path.isfile(fpath):
return False return False
md5o = hashlib.md5() md5o = hashlib.md5()
with open(fpath,'rb') as f: with open(fpath, 'rb') as f:
# read in 1MB chunks # read in 1MB chunks
for chunk in iter(lambda: f.read(1024 * 1024 * 1024), b''): for chunk in iter(lambda: f.read(1024 * 1024 * 1024), b''):
md5o.update(chunk) md5o.update(chunk)
...@@ -16,7 +18,7 @@ def check_integrity(fpath, md5): ...@@ -16,7 +18,7 @@ def check_integrity(fpath, md5):
return True return True
def download_url(url, root, filename, md5=None): def download_url(url, root, filename, md5):
from six.moves import urllib from six.moves import urllib
fpath = os.path.join(root, filename) fpath = os.path.join(root, filename)
......
...@@ -31,6 +31,7 @@ def inception_v3(pretrained=False, **kwargs): ...@@ -31,6 +31,7 @@ def inception_v3(pretrained=False, **kwargs):
class Inception3(nn.Module): class Inception3(nn.Module):
def __init__(self, num_classes=1000, aux_logits=True, transform_input=False): def __init__(self, num_classes=1000, aux_logits=True, transform_input=False):
super(Inception3, self).__init__() super(Inception3, self).__init__()
self.aux_logits = aux_logits self.aux_logits = aux_logits
...@@ -126,6 +127,7 @@ class Inception3(nn.Module): ...@@ -126,6 +127,7 @@ class Inception3(nn.Module):
class InceptionA(nn.Module): class InceptionA(nn.Module):
def __init__(self, in_channels, pool_features): def __init__(self, in_channels, pool_features):
super(InceptionA, self).__init__() super(InceptionA, self).__init__()
self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1) self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)
...@@ -157,6 +159,7 @@ class InceptionA(nn.Module): ...@@ -157,6 +159,7 @@ class InceptionA(nn.Module):
class InceptionB(nn.Module): class InceptionB(nn.Module):
def __init__(self, in_channels): def __init__(self, in_channels):
super(InceptionB, self).__init__() super(InceptionB, self).__init__()
self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2) self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2)
...@@ -179,6 +182,7 @@ class InceptionB(nn.Module): ...@@ -179,6 +182,7 @@ class InceptionB(nn.Module):
class InceptionC(nn.Module): class InceptionC(nn.Module):
def __init__(self, in_channels, channels_7x7): def __init__(self, in_channels, channels_7x7):
super(InceptionC, self).__init__() super(InceptionC, self).__init__()
self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1) self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1)
...@@ -217,6 +221,7 @@ class InceptionC(nn.Module): ...@@ -217,6 +221,7 @@ class InceptionC(nn.Module):
class InceptionD(nn.Module): class InceptionD(nn.Module):
def __init__(self, in_channels): def __init__(self, in_channels):
super(InceptionD, self).__init__() super(InceptionD, self).__init__()
self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
...@@ -242,6 +247,7 @@ class InceptionD(nn.Module): ...@@ -242,6 +247,7 @@ class InceptionD(nn.Module):
class InceptionE(nn.Module): class InceptionE(nn.Module):
def __init__(self, in_channels): def __init__(self, in_channels):
super(InceptionE, self).__init__() super(InceptionE, self).__init__()
self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1) self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1)
...@@ -283,6 +289,7 @@ class InceptionE(nn.Module): ...@@ -283,6 +289,7 @@ class InceptionE(nn.Module):
class InceptionAux(nn.Module): class InceptionAux(nn.Module):
def __init__(self, in_channels, num_classes): def __init__(self, in_channels, num_classes):
super(InceptionAux, self).__init__() super(InceptionAux, self).__init__()
self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1) self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1)
...@@ -307,6 +314,7 @@ class InceptionAux(nn.Module): ...@@ -307,6 +314,7 @@ class InceptionAux(nn.Module):
class BasicConv2d(nn.Module): class BasicConv2d(nn.Module):
def __init__(self, in_channels, out_channels, **kwargs): def __init__(self, in_channels, out_channels, **kwargs):
super(BasicConv2d, self).__init__() super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment