Commit b13bac51 authored by soumith, committed by Soumith Chintala
Browse files

refactored phototour to use utils

parent 32460f52
......@@ -6,5 +6,5 @@ max-line-length = 120
[flake8]
max-line-length = 120
ignore = F401,F403
ignore = F401,E402,F403
exclude = venv
......@@ -5,6 +5,7 @@ from .cifar import CIFAR10, CIFAR100
from .stl10 import STL10
from .mnist import MNIST
from .svhn import SVHN
from .phototour import PhotoTour
__all__ = ('LSUN', 'LSUNClass',
'ImageFolder',
......
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
......@@ -11,6 +10,7 @@ if sys.version_info[0] == 2:
else:
import pickle
import torch.utils.data as data
from .utils import download_url, check_integrity
......
......@@ -127,15 +127,3 @@ class LSUN(data.Dataset):
def __repr__(self):
    """Return a debug string: the class name followed by the backing db path."""
    # e.g. "LSUNClass (/home/user/lsun/bedroom_train_lmdb)"
    fmt = '{} ({})'
    return fmt.format(type(self).__name__, self.db_path)
# NOTE(review): manual smoke-test entry point for the LSUN dataset; it reads
# lmdb databases from hard-coded paths on the original author's machine, so it
# is not runnable elsewhere.  (Leading indentation was lost in extraction.)
if __name__ == '__main__':
# commented-out single-class variant kept by the author
# lsun = LSUNClass(db_path='/home/soumith/local/lsun/train/bedroom_train_lmdb')
# a = lsun[0]
lsun = LSUN(db_path='/home/soumith/local/lsun/train',
classes=['bedroom_train', 'church_outdoor_train'])
# print the discovered class list and per-class db handles
print(lsun.classes)
print(lsun.dbs)
# fetch the last (image, target) pair to exercise indexing end-to-end
a, t = lsun[len(lsun) - 1]
print(a)
print(t)
......@@ -6,12 +6,26 @@ from PIL import Image
import torch
import torch.utils.data as data
from .utils import download_url, check_integrity
# PhotoTour: the photo-tourism patch dataset (notredame / yosemite / liberty).
# NOTE(review): this span was extracted from a diff view; the three
# single-string entries directly below duplicate the three list entries and
# appear to be stale pre-refactor residue -- only one form can be the real
# source.  TODO confirm against the repository.
class PhotoTour(data.Dataset):
urls = {
'notredame': 'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/notredame.zip',
'yosemite': 'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/yosemite.zip',
'liberty': 'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/liberty.zip'
# each entry below is [download url, archive filename, md5 checksum]
# (indexed as [0], [1], [2] by download())
'notredame': [
'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/notredame.zip',
'notredame.zip',
'509eda8535847b8c0a90bbb210c83484'
],
'yosemite': [
'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/yosemite.zip',
'yosemite.zip',
'533b2e8eb7ede31be40abc317b2fd4f0'
],
'liberty': [
'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/liberty.zip',
'liberty.zip',
'fdd9152f138ea5ef2091746689176414'
],
}
# per-dataset scalar mean/std -- presumably of the grayscale patch
# intensities, for normalization; TODO confirm against training code
mean = {'notredame': 0.4854, 'yosemite': 0.4844, 'liberty': 0.4437}
std = {'notredame': 0.1864, 'yosemite': 0.1818, 'liberty': 0.2019}
......@@ -37,7 +51,7 @@ class PhotoTour(data.Dataset):
if download:
self.download()
if not self._check_exists():
if not self._check_datafile_exists():
raise RuntimeError('Dataset not found.' +
' You can use download=True to download it')
......@@ -62,59 +76,45 @@ class PhotoTour(data.Dataset):
return self.lens[self.name]
return len(self.matches)
def _check_exists(self):
def _check_datafile_exists(self):
    """Return True when the processed/cached data file is already on disk."""
    cached = self.data_file
    return os.path.exists(cached)
def _check_downloaded(self):
    """Return True when the raw (extracted) data directory is present."""
    raw_dir = self.data_dir
    return os.path.exists(raw_dir)
# NOTE(review): this span interleaves pre- and post-refactor lines from a
# diff view (duplicate "if", "with zipfile", "unlink", and tuple-assignment
# lines); it is not valid Python as shown.  The post-refactor version
# delegates fetching/checksumming to download_url().  TODO confirm against
# the repository before editing.
def download(self):
from six.moves import urllib
print('\n-- Loading PhotoTour dataset: {}\n'.format(self.name))
# fast path: processed .pt cache already exists, nothing to do
if self._check_exists():
if self._check_datafile_exists():
print('# Found cached data {}'.format(self.data_file))
return
if not self._check_downloaded():
# download files
# --- pre-refactor lines: derive filename from the url string ---
url = self.urls[self.name]
filename = url.rpartition('/')[2]
file_path = os.path.join(self.root, filename)
try:
os.makedirs(self.root)
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise
# --- post-refactor lines: urls entry is [url, filename, md5] ---
url = self.urls[self.name][0]
filename = self.urls[self.name][1]
md5 = self.urls[self.name][2]
fpath = os.path.join(self.root, filename)
print('# Downloading {} into {}\n\nIt might take while.'
' Please grab yourself a coffee and relax.'
.format(url, file_path))
# pre-refactor: raw urlretrieve; post-refactor: checksummed helper
urllib.request.urlretrieve(url, file_path)
assert os.path.exists(file_path)
download_url(url, self.root, filename, md5)
print('# Extracting data {}\n'.format(self.data_down))
import zipfile
with zipfile.ZipFile(file_path, 'r') as z:
with zipfile.ZipFile(fpath, 'r') as z:
z.extractall(self.data_dir)
# remove the archive once extracted
os.unlink(file_path)
os.unlink(fpath)
# process and save as torch files
print('# Caching data {}'.format(self.data_file))
data_set = (
dataset = (
read_image_file(self.data_dir, self.image_ext, self.lens[self.name]),
read_info_file(self.data_dir, self.info_file),
read_matches_files(self.data_dir, self.matches_files)
)
# serialize the (images, info, matches) tuple as a single torch file
with open(self.data_file, 'wb') as f:
torch.save(data_set, f)
torch.save(dataset, f)
def read_image_file(data_dir, image_ext, n):
......@@ -138,8 +138,8 @@ def read_image_file(data_dir, image_ext, n):
patches = []
list_files = find_files(data_dir, image_ext)
for file_path in list_files:
img = Image.open(file_path)
for fpath in list_files:
img = Image.open(fpath)
for y in range(0, 1024, 64):
for x in range(0, 1024, 64):
patch = img.crop((x, y, x + 64, y + 64))
......@@ -168,14 +168,3 @@ def read_matches_files(data_dir, matches_file):
l = line.split()
matches.append([int(l[0]), int(l[3]), int(l[1] == l[4])])
return torch.LongTensor(matches)
# NOTE(review): manual smoke test for PhotoTour; downloads the full
# 'notredame' archive over the network into a hard-coded personal path,
# so it is not runnable as an automated test.
if __name__ == '__main__':
dataset = PhotoTour(root='/home/eriba/datasets/patches_dataset',
name='notredame',
download=True)
print('Loaded PhotoTour: {} with {} images.'
.format(dataset.name, len(dataset.data)))
# sanity check: one label per patch image
assert len(dataset.data) == len(dataset.labels)
......@@ -26,7 +26,8 @@ class STL10(CIFAR10):
['test_y.bin', '36f9794fa4beb8a2c72628de14fa638e']
]
def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
def __init__(self, root, split='train',
transform=None, target_transform=None, download=False):
self.root = root
self.transform = transform
self.target_transform = target_transform
......@@ -37,7 +38,8 @@ class STL10(CIFAR10):
if not self._check_integrity():
raise RuntimeError(
'Dataset not found or corrupted. You can use download=True to download it')
'Dataset not found or corrupted. '
'You can use download=True to download it')
# now load the picked numpy arrays
if self.split == 'train':
......
......@@ -8,6 +8,7 @@ import numpy as np
import sys
from .utils import download_url, check_integrity
class SVHN(data.Dataset):
url = ""
filename = ""
......
import os
import os.path
import hashlib
import errno
def check_integrity(fpath, md5):
    """Return True iff *fpath* is an existing file whose MD5 digest matches.

    Args:
        fpath (str): path of the file to verify.
        md5 (str): expected MD5 hex digest.

    Returns:
        bool: False when the file is missing or the digest differs.
    """
    if not os.path.isfile(fpath):
        return False
    md5o = hashlib.md5()
    with open(fpath, 'rb') as f:
        # read in 1MB chunks so large archives are hashed without being
        # loaded fully into memory (the original read 1024**3 bytes -- a
        # full 1 GiB per chunk -- contradicting its own comment)
        for chunk in iter(lambda: f.read(1024 * 1024), b''):
            md5o.update(chunk)
    md5c = md5o.hexdigest()
    if md5c != md5:
        return False
    return True
def download_url(url, root, filename, md5=None):
def download_url(url, root, filename, md5):
from six.moves import urllib
fpath = os.path.join(root, filename)
......
......@@ -31,6 +31,7 @@ def inception_v3(pretrained=False, **kwargs):
class Inception3(nn.Module):
def __init__(self, num_classes=1000, aux_logits=True, transform_input=False):
super(Inception3, self).__init__()
self.aux_logits = aux_logits
......@@ -126,6 +127,7 @@ class Inception3(nn.Module):
class InceptionA(nn.Module):
def __init__(self, in_channels, pool_features):
super(InceptionA, self).__init__()
self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)
......@@ -157,6 +159,7 @@ class InceptionA(nn.Module):
class InceptionB(nn.Module):
def __init__(self, in_channels):
super(InceptionB, self).__init__()
self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2)
......@@ -179,6 +182,7 @@ class InceptionB(nn.Module):
class InceptionC(nn.Module):
def __init__(self, in_channels, channels_7x7):
super(InceptionC, self).__init__()
self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1)
......@@ -217,6 +221,7 @@ class InceptionC(nn.Module):
class InceptionD(nn.Module):
def __init__(self, in_channels):
super(InceptionD, self).__init__()
self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
......@@ -242,6 +247,7 @@ class InceptionD(nn.Module):
class InceptionE(nn.Module):
def __init__(self, in_channels):
super(InceptionE, self).__init__()
self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1)
......@@ -283,6 +289,7 @@ class InceptionE(nn.Module):
class InceptionAux(nn.Module):
def __init__(self, in_channels, num_classes):
super(InceptionAux, self).__init__()
self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1)
......@@ -307,6 +314,7 @@ class InceptionAux(nn.Module):
class BasicConv2d(nn.Module):
def __init__(self, in_channels, out_channels, **kwargs):
super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment