Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
b13bac51
Commit
b13bac51
authored
Mar 17, 2017
by
soumith
Committed by
Soumith Chintala
Mar 17, 2017
Browse files
refactored phototour to use utils
parent
32460f52
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
52 additions
and
61 deletions
+52
-61
setup.cfg
setup.cfg
+1
-1
torchvision/datasets/__init__.py
torchvision/datasets/__init__.py
+1
-0
torchvision/datasets/cifar.py
torchvision/datasets/cifar.py
+1
-1
torchvision/datasets/lsun.py
torchvision/datasets/lsun.py
+0
-12
torchvision/datasets/phototour.py
torchvision/datasets/phototour.py
+32
-43
torchvision/datasets/stl10.py
torchvision/datasets/stl10.py
+4
-2
torchvision/datasets/svhn.py
torchvision/datasets/svhn.py
+1
-0
torchvision/datasets/utils.py
torchvision/datasets/utils.py
+4
-2
torchvision/models/inception.py
torchvision/models/inception.py
+8
-0
No files found.
setup.cfg
View file @
b13bac51
...
...
@@ -6,5 +6,5 @@ max-line-length = 120
[flake8]
max-line-length = 120
ignore = F401,F403
ignore = F401,
E402,
F403
exclude = venv
torchvision/datasets/__init__.py
View file @
b13bac51
...
...
@@ -5,6 +5,7 @@ from .cifar import CIFAR10, CIFAR100
from
.stl10
import
STL10
from
.mnist
import
MNIST
from
.svhn
import
SVHN
from
.phototour
import
PhotoTour
__all__
=
(
'LSUN'
,
'LSUNClass'
,
'ImageFolder'
,
...
...
torchvision/datasets/cifar.py
View file @
b13bac51
from
__future__
import
print_function
import
torch.utils.data
as
data
from
PIL
import
Image
import
os
import
os.path
...
...
@@ -11,6 +10,7 @@ if sys.version_info[0] == 2:
else
:
import
pickle
import
torch.utils.data
as
data
from
.utils
import
download_url
,
check_integrity
...
...
torchvision/datasets/lsun.py
View file @
b13bac51
...
...
@@ -127,15 +127,3 @@ class LSUN(data.Dataset):
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
' ('
+
self
.
db_path
+
')'
if
__name__
==
'__main__'
:
# lsun = LSUNClass(db_path='/home/soumith/local/lsun/train/bedroom_train_lmdb')
# a = lsun[0]
lsun
=
LSUN
(
db_path
=
'/home/soumith/local/lsun/train'
,
classes
=
[
'bedroom_train'
,
'church_outdoor_train'
])
print
(
lsun
.
classes
)
print
(
lsun
.
dbs
)
a
,
t
=
lsun
[
len
(
lsun
)
-
1
]
print
(
a
)
print
(
t
)
torchvision/datasets/phototour.py
View file @
b13bac51
...
...
@@ -6,12 +6,26 @@ from PIL import Image
import
torch
import
torch.utils.data
as
data
from
.utils
import
download_url
,
check_integrity
class
PhotoTour
(
data
.
Dataset
):
urls
=
{
'notredame'
:
'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/notredame.zip'
,
'yosemite'
:
'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/yosemite.zip'
,
'liberty'
:
'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/liberty.zip'
'notredame'
:
[
'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/notredame.zip'
,
'notredame.zip'
,
'509eda8535847b8c0a90bbb210c83484'
],
'yosemite'
:
[
'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/yosemite.zip'
,
'yosemite.zip'
,
'533b2e8eb7ede31be40abc317b2fd4f0'
],
'liberty'
:
[
'http://www.iis.ee.ic.ac.uk/~vbalnt/phototourism-patches/liberty.zip'
,
'liberty.zip'
,
'fdd9152f138ea5ef2091746689176414'
],
}
mean
=
{
'notredame'
:
0.4854
,
'yosemite'
:
0.4844
,
'liberty'
:
0.4437
}
std
=
{
'notredame'
:
0.1864
,
'yosemite'
:
0.1818
,
'liberty'
:
0.2019
}
...
...
@@ -37,7 +51,7 @@ class PhotoTour(data.Dataset):
if
download
:
self
.
download
()
if
not
self
.
_check_exists
():
if
not
self
.
_check_
datafile_
exists
():
raise
RuntimeError
(
'Dataset not found.'
+
' You can use download=True to download it'
)
...
...
@@ -62,59 +76,45 @@ class PhotoTour(data.Dataset):
return
self
.
lens
[
self
.
name
]
return
len
(
self
.
matches
)
def
_check_exists
(
self
):
def
_check_
datafile_
exists
(
self
):
return
os
.
path
.
exists
(
self
.
data_file
)
def
_check_downloaded
(
self
):
return
os
.
path
.
exists
(
self
.
data_dir
)
def
download
(
self
):
from
six.moves
import
urllib
print
(
'
\n
-- Loading PhotoTour dataset: {}
\n
'
.
format
(
self
.
name
))
if
self
.
_check_exists
():
if
self
.
_check_datafile_exists
():
print
(
'# Found cached data {}'
.
format
(
self
.
data_file
))
return
if
not
self
.
_check_downloaded
():
# download files
url
=
self
.
urls
[
self
.
name
]
filename
=
url
.
rpartition
(
'/'
)[
2
]
file_path
=
os
.
path
.
join
(
self
.
root
,
filename
)
try
:
os
.
makedirs
(
self
.
root
)
except
OSError
as
e
:
if
e
.
errno
==
errno
.
EEXIST
:
pass
else
:
raise
url
=
self
.
urls
[
self
.
name
][
0
]
filename
=
self
.
urls
[
self
.
name
][
1
]
md5
=
self
.
urls
[
self
.
name
][
2
]
fpath
=
os
.
path
.
join
(
self
.
root
,
filename
)
print
(
'# Downloading {} into {}
\n\n
It might take while.'
' Please grab yourself a coffee and relax.'
.
format
(
url
,
file_path
))
urllib
.
request
.
urlretrieve
(
url
,
file_path
)
assert
os
.
path
.
exists
(
file_path
)
download_url
(
url
,
self
.
root
,
filename
,
md5
)
print
(
'# Extracting data {}
\n
'
.
format
(
self
.
data_down
))
import
zipfile
with
zipfile
.
ZipFile
(
f
ile_
path
,
'r'
)
as
z
:
with
zipfile
.
ZipFile
(
fpath
,
'r'
)
as
z
:
z
.
extractall
(
self
.
data_dir
)
os
.
unlink
(
file_path
)
os
.
unlink
(
fpath
)
# process and save as torch files
print
(
'# Caching data {}'
.
format
(
self
.
data_file
))
data
_
set
=
(
dataset
=
(
read_image_file
(
self
.
data_dir
,
self
.
image_ext
,
self
.
lens
[
self
.
name
]),
read_info_file
(
self
.
data_dir
,
self
.
info_file
),
read_matches_files
(
self
.
data_dir
,
self
.
matches_files
)
)
with
open
(
self
.
data_file
,
'wb'
)
as
f
:
torch
.
save
(
data
_
set
,
f
)
torch
.
save
(
dataset
,
f
)
def
read_image_file
(
data_dir
,
image_ext
,
n
):
...
...
@@ -138,8 +138,8 @@ def read_image_file(data_dir, image_ext, n):
patches
=
[]
list_files
=
find_files
(
data_dir
,
image_ext
)
for
f
ile_
path
in
list_files
:
img
=
Image
.
open
(
f
ile_
path
)
for
fpath
in
list_files
:
img
=
Image
.
open
(
fpath
)
for
y
in
range
(
0
,
1024
,
64
):
for
x
in
range
(
0
,
1024
,
64
):
patch
=
img
.
crop
((
x
,
y
,
x
+
64
,
y
+
64
))
...
...
@@ -168,14 +168,3 @@ def read_matches_files(data_dir, matches_file):
l
=
line
.
split
()
matches
.
append
([
int
(
l
[
0
]),
int
(
l
[
3
]),
int
(
l
[
1
]
==
l
[
4
])])
return
torch
.
LongTensor
(
matches
)
if
__name__
==
'__main__'
:
dataset
=
PhotoTour
(
root
=
'/home/eriba/datasets/patches_dataset'
,
name
=
'notredame'
,
download
=
True
)
print
(
'Loaded PhotoTour: {} with {} images.'
.
format
(
dataset
.
name
,
len
(
dataset
.
data
)))
assert
len
(
dataset
.
data
)
==
len
(
dataset
.
labels
)
torchvision/datasets/stl10.py
View file @
b13bac51
...
...
@@ -26,7 +26,8 @@ class STL10(CIFAR10):
[
'test_y.bin'
,
'36f9794fa4beb8a2c72628de14fa638e'
]
]
def
__init__
(
self
,
root
,
split
=
'train'
,
transform
=
None
,
target_transform
=
None
,
download
=
False
):
def
__init__
(
self
,
root
,
split
=
'train'
,
transform
=
None
,
target_transform
=
None
,
download
=
False
):
self
.
root
=
root
self
.
transform
=
transform
self
.
target_transform
=
target_transform
...
...
@@ -37,7 +38,8 @@ class STL10(CIFAR10):
if
not
self
.
_check_integrity
():
raise
RuntimeError
(
'Dataset not found or corrupted. You can use download=True to download it'
)
'Dataset not found or corrupted. '
'You can use download=True to download it'
)
# now load the picked numpy arrays
if
self
.
split
==
'train'
:
...
...
torchvision/datasets/svhn.py
View file @
b13bac51
...
...
@@ -8,6 +8,7 @@ import numpy as np
import
sys
from
.utils
import
download_url
,
check_integrity
class
SVHN
(
data
.
Dataset
):
url
=
""
filename
=
""
...
...
torchvision/datasets/utils.py
View file @
b13bac51
import
os
import
os.path
import
hashlib
import
errno
def
check_integrity
(
fpath
,
md5
):
if
not
os
.
path
.
isfile
(
fpath
):
return
False
md5o
=
hashlib
.
md5
()
with
open
(
fpath
,
'rb'
)
as
f
:
with
open
(
fpath
,
'rb'
)
as
f
:
# read in 1MB chunks
for
chunk
in
iter
(
lambda
:
f
.
read
(
1024
*
1024
*
1024
),
b
''
):
md5o
.
update
(
chunk
)
...
...
@@ -16,7 +18,7 @@ def check_integrity(fpath, md5):
return
True
def
download_url
(
url
,
root
,
filename
,
md5
=
None
):
def
download_url
(
url
,
root
,
filename
,
md5
):
from
six.moves
import
urllib
fpath
=
os
.
path
.
join
(
root
,
filename
)
...
...
torchvision/models/inception.py
View file @
b13bac51
...
...
@@ -31,6 +31,7 @@ def inception_v3(pretrained=False, **kwargs):
class
Inception3
(
nn
.
Module
):
def
__init__
(
self
,
num_classes
=
1000
,
aux_logits
=
True
,
transform_input
=
False
):
super
(
Inception3
,
self
).
__init__
()
self
.
aux_logits
=
aux_logits
...
...
@@ -126,6 +127,7 @@ class Inception3(nn.Module):
class
InceptionA
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
pool_features
):
super
(
InceptionA
,
self
).
__init__
()
self
.
branch1x1
=
BasicConv2d
(
in_channels
,
64
,
kernel_size
=
1
)
...
...
@@ -157,6 +159,7 @@ class InceptionA(nn.Module):
class
InceptionB
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
):
super
(
InceptionB
,
self
).
__init__
()
self
.
branch3x3
=
BasicConv2d
(
in_channels
,
384
,
kernel_size
=
3
,
stride
=
2
)
...
...
@@ -179,6 +182,7 @@ class InceptionB(nn.Module):
class
InceptionC
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
channels_7x7
):
super
(
InceptionC
,
self
).
__init__
()
self
.
branch1x1
=
BasicConv2d
(
in_channels
,
192
,
kernel_size
=
1
)
...
...
@@ -217,6 +221,7 @@ class InceptionC(nn.Module):
class
InceptionD
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
):
super
(
InceptionD
,
self
).
__init__
()
self
.
branch3x3_1
=
BasicConv2d
(
in_channels
,
192
,
kernel_size
=
1
)
...
...
@@ -242,6 +247,7 @@ class InceptionD(nn.Module):
class
InceptionE
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
):
super
(
InceptionE
,
self
).
__init__
()
self
.
branch1x1
=
BasicConv2d
(
in_channels
,
320
,
kernel_size
=
1
)
...
...
@@ -283,6 +289,7 @@ class InceptionE(nn.Module):
class
InceptionAux
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
num_classes
):
super
(
InceptionAux
,
self
).
__init__
()
self
.
conv0
=
BasicConv2d
(
in_channels
,
128
,
kernel_size
=
1
)
...
...
@@ -307,6 +314,7 @@ class InceptionAux(nn.Module):
class
BasicConv2d
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
out_channels
,
**
kwargs
):
super
(
BasicConv2d
,
self
).
__init__
()
self
.
conv
=
nn
.
Conv2d
(
in_channels
,
out_channels
,
bias
=
False
,
**
kwargs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment